From 6077feade5f2689cf76a62a17e331956576d4c92 Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 22 Feb 2026 23:59:15 +0000 Subject: [PATCH 01/24] Add int8 quantization for vortex. Key changes: 1. Memory Pool (`vtx_graph_memory_pool.py`): - Removed hardcoded bf16 assertions in `VTXGraphCachePool` to support `torch.int8` allocations. - Added parallel `float32` scale buffers (`k_scale`, `v_scale`) mapped to the paged layout. - Preserved `bfloat16` shadow buffers (`k_bf16`) for auxiliary metadata (e.g., centroids) to ensure the Vortex sparse indexer/TopK remains unaffected and mathematically identical. 2. Quantize-on-Write (`set_kv.py`): - Implemented a custom Triton kernel (`set_kv_buffer_int8_kernel`) that quantizes incoming `bf16` tokens into `int8` on the fly using per-token absmax scaling (`scale = max(abs(x)) / 127.0`). - Wired the new launcher into the cache update flow. 3. Decode Path (`vtx_graph_backend.py` & `paged_decode_int8.py`): - Bypassed FlashInfer for INT8 decoding. - Wired in the custom Triton decode kernel (`paged_decode_int8`) that reads the `int8` pages and `float32` scales directly into SRAM, performing fused inline dequantization without allocating temporary full-cache VRAM buffers. - Seamlessly integrated with existing sparse routing indices (`indptr`, `indices`). 4. Prefill Path (`vtx_graph_backend.py` & `paged_prefill_int8.py`): - Implemented an OOM-safe `bf16` fallback for prefill. - Added a new Triton kernel (`dequant_paged_int8_to_bf16`) to dynamically extract and dequantize *only the accessed pages* for the current batch into a tiny, compacted `bf16` buffer. - Modified the FlashInfer `BatchPrefillWithPagedKVCacheWrapper` planner to map over the compacted subset indices, entirely avoiding full-cache dequantization OOMs. --- CLAUDE.md | 88 +++++ examples/verify_algo.py | 27 +- examples/verify_algo.sh | 27 +- examples/verify_algo_quant.sh | 25 ++ vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/triton_kernels/__init__.py | 11 +- .../cache/triton_kernels/paged_decode_int8.py | 355 ++++++++++++++++++ .../triton_kernels/paged_prefill_int8.py | 90 +++++ vortex_torch/cache/triton_kernels/set_kv.py | 87 +++++ 9 files changed, 696 insertions(+), 17 deletions(-) create mode 100644 CLAUDE.md create mode 100644 examples/verify_algo_quant.sh create mode 100644 vortex_torch/cache/triton_kernels/paged_decode_int8.py create mode 100644 vortex_torch/cache/triton_kernels/paged_prefill_int8.py diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..db54c757 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,88 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. + +## Build & Install + +```bash +# Install SGLang dependency (custom fork in third_party/) +cd third_party/sglang && bash install.sh && cd ../../ + +# Install Vortex (editable mode, compiles CUDA extensions for SM_89/SM_90) +pip install -e . +``` + +Requires Python >=3.10, torch>=2.7. CUDA extensions are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu). + +## Running Examples + +```bash +# Single algorithm verification against SGLang +python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention + +# Batch test multiple algorithms +bash examples/verify_algo.sh +``` + +## Building Documentation + +```bash +make -C docs html +``` + +Uses Sphinx with myst_parser and furo theme. Deployed via GitHub Actions on push to v1 branch. + +## Architecture + +### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) + +All sparse attention algorithms inherit from `vFlow` and implement three methods: + +- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. +- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. +- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. + +Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. + +### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) + +Operators (`vOp` subclasses) run in two modes: +- **Profile mode**: Pre-compute output shapes and allocate buffers +- **Execute mode**: Perform actual GPU computation + +Operators are split into two parallel hierarchies: +- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load +- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup + +Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. + +### Tensor Format (`vortex_torch/abs/tensor.py`) + +`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. + +### Context System (`vortex_torch/abs/context_base.py`) + +`ContextBase` carries per-step runtime state. Specialized as: +- `Indexer.Context`: Page layout, head config, hardware info +- `Cache.Context`: Page size, total pages, model info + +### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) + +- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) +- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation +- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds + +### SGLang Integration + +Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, and transpose operations. + +## Key Conventions + +- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` +- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) +- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` +- **Branch**: Main development is on `v1` diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e290a81b..f4185983 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -54,7 +54,8 @@ def verify_algos( vortex_module_name: str = "gqa_block_sparse_attention", model_name: str = "Qwen/Qwen3-1.7B", sparse_attention: bool = True, -mem: float = 0.8 +mem: float = 0.8, +kv_cache_dtype: str = "auto", ): llm = sgl.Engine(model_path=model_name, @@ -69,10 +70,11 @@ def verify_algos( vortex_layers_skip=list(range(1)), vortex_module_name=vortex_module_name, vortex_max_seq_lens=12288, - mem_fraction_static=mem + mem_fraction_static=mem, + kv_cache_dtype=kv_cache_dtype, ) - with open("examples/amc23.jsonl", "r", encoding="utf-8") as f: + with open("amc23.jsonl", "r", encoding="utf-8") as f: requests = [json.loads(line) for line in f] requests = requests * trials @@ -110,6 +112,14 @@ def verify_algos( "num_tokens": item["meta_info"]["completion_tokens"] } ) + # --- Per-question debug output --- + print(f"[Q{len(results):03d}] score={float(result):.1f} " + f"tokens={item['meta_info']['completion_tokens']} " + f"latency={item['meta_info']['e2e_latency']:.2f}s " + f"gold={golds[0]}") + print(f" question: {data['question'][:120]}...") + print(f" prediction: {predictions[:200]}...") + print() total_accuracy = 0.0 @@ -203,6 +213,14 @@ def parse_args(): default=0.8, help="memory fraction in sglang", ) + + parser.add_argument( + "--kv-cache-dtype", + type=str, + default="auto", + choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], + help='KV cache dtype (default: "auto").', + ) return parser.parse_args() if __name__ == "__main__": @@ -215,7 +233,8 @@ def parse_args(): vortex_module_name=args.vortex_module_name, model_name=args.model_name, sparse_attention=not(args.full_attention), - mem=args.mem + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, ) print(summary) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 17c2a5ed..d80f09a5 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,17 +1,24 @@ #!/usr/bin/env bash set -e +export CUDA_VISIBLE_DEVICES=1 sparse_algos=( - + "block_sparse_attention" ) -for algo in "${sparse_algos[@]}"; do - echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" - python examples/verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --mem 0.7 -done +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype bf16" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh new file mode 100644 index 00000000..4cf1366b --- /dev/null +++ b/examples/verify_algo_quant.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=2 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done \ No newline at end of file diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index eddfa464..b8865596 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,11 +29,12 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher __all__ = [ "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 6bf6dfcd..2d6384ff 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,4 +1,11 @@ -from .set_kv import set_kv_buffer_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .paged_decode_int8 import paged_decode_int8 +from .paged_prefill_int8 import dequant_paged_int8_to_bf16 -__all__ = ["set_kv_buffer_launcher"] +__all__ = [ + "set_kv_buffer_launcher", + "set_kv_buffer_int8_launcher", + "paged_decode_int8", + "dequant_paged_int8_to_bf16", +] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py new file mode 100644 index 00000000..480c787e --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -0,0 +1,355 @@ +""" +Custom Triton paged decode attention kernel for int8 KV cache. + +Loads int8 K/V pages with per-token float32 scales, dequantizes inline in SRAM, +and computes standard multi-head attention with online softmax. + +Adapted from SGLang's decode_attention.py for use with Vortex's paged layout +where each KV head is treated as a separate "batch" entry. +""" + +import torch +import triton +import triton.language as tl + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_int8_stage1( + Q, # [batch, num_qo_heads, head_dim] bf16 + K_Buffer, # int8 paged: flat + V_Buffer, # int8 paged: flat + K_Scale_Buffer, # float32: flat (one scale per token slot) + V_Scale_Buffer, # float32: flat + sm_scale, + kv_indptr, # [batch + 1] int32, page-level + kv_indices, # page indices + last_page_len, # [batch] int32, tokens valid in last page + Att_Out, # [batch, num_qo_heads, max_kv_splits, head_dim] + Att_Lse, # [batch, num_qo_heads, max_kv_splits] + num_kv_splits, # [batch] int32 + stride_qbs, + stride_qh, + stride_buf_kbs, # stride per token in K_Buffer (= head_dim) + stride_buf_vbs, # stride per token in V_Buffer (= head_dim) + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """ + Stage 1: For each (batch, head, kv_split), compute partial attention output and LSE. + + kv_indptr is page-level. Total tokens for batch i: + (num_pages - 1) * PAGE_SIZE + last_page_len[i] + """ + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + # Correct token count accounting for partial last page + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + # Convert token offsets to page_id + in-page offset + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + + # Load page indices from kv_indices (physical page IDs) + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, + other=0, + ) + + # Flat token location: physical_page * PAGE_SIZE + in_page_offset + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load int8 K and dequantize + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], + other=0, + ).to(tl.float32) + + k_scale = tl.load( + K_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ) + k = k_int8 * k_scale[:, None] + + # Compute QK + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load int8 V and dequantize + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], + other=0, + ).to(tl.float32) + + v_scale = tl.load( + V_Scale_Buffer + kv_loc, + mask=mask_n, + other=1.0, + ) + v = v_int8 * v_scale[:, None] + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=mask_dv, + ) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store( + Att_Lse + offs_mid_o_1, + e_max + tl.log(e_sum), + ) + + +@triton.jit +def _fwd_kernel_int8_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode_int8( + q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 + k_buffer: torch.Tensor, # int8 paged K cache + v_buffer: torch.Tensor, # int8 paged V cache + k_scale_buffer: torch.Tensor, # float32 scale for K + v_scale_buffer: torch.Tensor, # float32 scale for V + o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output + kv_indptr: torch.Tensor, # [batch + 1] int32, page-level + kv_indices: torch.Tensor, # page indices + last_page_len: torch.Tensor, # [batch] int32 + num_kv_splits: torch.Tensor, # [batch] int32 + max_kv_splits: int, + sm_scale: float, + page_size: int, + logit_cap: float = 0.0, +): + """ + Paged decode attention with int8 KV cache and inline dequantization. + + kv_indptr is page-level. last_page_len specifies valid tokens in the last page + for each batch entry. Total tokens = (num_pages - 1) * page_size + last_page_len. + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 64 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + + num_warps = 4 if kv_group_num == 1 else 2 + + # Intermediate buffers for split reduction + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_int8_stage1[grid_stage1]( + q, + k_buffer, + v_buffer, + k_scale_buffer, + v_scale_buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + att_out, + att_lse, + num_kv_splits, + q.stride(0), + q.stride(1), + stride_buf_kbs, + stride_buf_vbs, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, + Lv=Lv, + PAGE_SIZE=page_size, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_int8_stage2[grid_stage2]( + att_out, + att_lse, + o, + kv_indptr, + last_page_len, + num_kv_splits, + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), + o.stride(0), + o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py new file mode 100644 index 00000000..75c38574 --- /dev/null +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -0,0 +1,90 @@ +""" +OOM-safe bf16 fallback for int8 KV-cache prefill. + +Instead of implementing full 2D-tiled Triton prefill with int8 dequantization, +this module dequantizes only the accessed KV pages into a compact temporary +bf16 buffer and remaps indices so FlashInfer can operate on the compact buffer. + +This avoids dequantizing the entire global cache buffer. +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _dequant_pages_kernel( + src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat + src_scale, # float32 scale buffer [num_pages, page_size, 1] flat + dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat + page_indices, # int32 [num_accessed_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16 compact buffer.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source: global_page_id * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + src_offset, mask=mask_dim, other=0).to(tl.float32) + + # Scale: global_page_id * PAGE_SIZE + token_idx + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Destination: page_idx * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + tl.store(dst_bf16 + dst_offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16( + src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] + src_scale: torch.Tensor, # float32 [num_pages, page_size, 1] + page_indices: torch.Tensor, # int32 [num_accessed_pages] + page_size: int, + head_dim: int, +) -> torch.Tensor: + """ + Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + + Returns: + bf16 tensor of shape [num_accessed_pages, page_size, head_dim] + """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) + + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) + + return dst_bf16 diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index cfa3cab2..43184280 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -36,6 +36,93 @@ def set_kv_buffer_kernel( tl.store(dst_v_ptr, src_v) +@triton.jit +def set_kv_buffer_int8_kernel( + k_cache, # int8 paged K cache + v_cache, # int8 paged V cache + k_scale_cache, # float32 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # float32 per-token V scale [num_pages, page_size, 1] + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr +): + """Quantize bf16 K/V to int8 with per-token absmax scaling and write to paged buffers.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Compute per-token absmax scale: scale = absmax / 127 + absmax_k = tl.max(tl.abs(src_k), axis=0) + absmax_v = tl.max(tl.abs(src_v), axis=0) + # Avoid division by zero + scale_k = absmax_k / 127.0 + 1e-10 + scale_v = absmax_v / 127.0 + 1e-10 + + # Quantize to int8: round(x / scale), clamp to [-128, 127] + q_k = tl.extra.cuda.libdevice.rint(src_k / scale_k) + q_k = tl.minimum(tl.maximum(q_k, -128.0), 127.0).to(tl.int8) + q_v = tl.extra.cuda.libdevice.rint(src_v / scale_v) + q_v = tl.minimum(tl.maximum(q_v, -128.0), 127.0).to(tl.int8) + + # Compute paged destination offset (same layout as bf16 kernel) + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write int8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + # Write per-token scales: shape [num_pages, page_size, 1] + # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) + scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset + tl.store(k_scale_cache + scale_offset, scale_k) + tl.store(v_scale_cache + scale_offset, scale_v) + + +def set_kv_buffer_int8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + k_scale_cache: torch.Tensor, + v_scale_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int +): + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_int8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, + v_cache, + k_scale_cache, + v_scale_cache, + new_k, + new_v, + loc, + NUM_KV_HEAD, + NNZ, + HEAD_DIM, + page_size + ) + + def set_kv_buffer_launcher( k_cache: torch.Tensor, v_cache: torch.Tensor, From 1f52772d451ce0e8663784c1052475b4f7b2618f Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 23 Feb 2026 03:19:56 +0000 Subject: [PATCH 02/24] 1. Add support for pro 6000. 2. Correction for vortex --- CLAUDE.md | 53 +++++++++++ setup.py | 5 +- third_party/sglang | 2 +- vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/triton_kernels/__init__.py | 3 +- .../cache/triton_kernels/paged_decode_int8.py | 42 +++++---- .../triton_kernels/paged_prefill_int8.py | 94 +++++++++++++++++-- vortex_torch/cache/triton_kernels/set_kv.py | 10 +- 8 files changed, 178 insertions(+), 34 deletions(-) diff --git a/CLAUDE.md b/CLAUDE.md index db54c757..1593d611 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -86,3 +86,56 @@ Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch) - **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) - **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` - **Branch**: Main development is on `v1` + +## Workflow Orchestration + +### 1. Plan Node Default +- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) +- If something goes sideways, STOP and re-plan immediately - don't keep pushing +- Use plan mode for verification steps, not just building +- Write detailed specs upfront to reduce ambiguity + +### 2. Subagent Strategy +- Use subagents liberally to keep main context window clean +- Offload research, exploration, and parallel analysis to subagents +- For complex problems, throw more compute at it via subagents +- One tack per subagent for focused execution + +### 3. Self-Improvement Loop +- After ANY correction from the user: update `tasks/lessons.md` with the pattern +- Write rules for yourself that prevent the same mistake +- Ruthlessly iterate on these lessons until mistake rate drops +- Review lessons at session start for relevant project + +### 4. Verification Before Done +- Never mark a task complete without proving it works +- Diff behavior between main and your changes when relevant +- Ask yourself: "Would a staff engineer approve this?" +- Run tests, check logs, demonstrate correctness + +### 5. Demand Elegance (Balanced) +- For non-trivial changes: pause and ask "is there a more elegant way?" +- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" +- Skip this for simple, obvious fixes - don't over-engineer +- Challenge your own work before presenting it + +### 6. Autonomous Bug Fixing +- When given a bug report: just fix it. Don't ask for hand-holding +- Point at logs, errors, failing tests - then resolve them +- Zero context switching required from the user +- Go fix failing CI tests without being told how + +## Task Management + +1. **Plan First**: Write plan to `tasks/todo.md` with checkable items +2. **Verify Plan**: Check in before starting implementation +3. **Track Progress**: Mark items complete as you go +4. **Explain Changes**: High-level summary at each step +5. **Document Results**: Add review section to `tasks/todo.md` +6. **Capture Lessons**: Update `tasks/lessons.md` after corrections + +## Core Principles + +- **Simplicity First**: Make every change as simple as possible. Impact minimal code. +- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. +- **Minimat Impact**: Changes should only touch what's necessary. Avoid introducing bugs. \ No newline at end of file diff --git a/setup.py b/setup.py index e2723268..6efeebe5 100644 --- a/setup.py +++ b/setup.py @@ -23,8 +23,11 @@ 'cxx': ['-O3'], 'nvcc': [ '-O3', + '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', - '-gencode=arch=compute_90,code=sm_90' + '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' ], }, ), diff --git a/third_party/sglang b/third_party/sglang index e383c0fd..9672e9a7 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit e383c0fdd551f74f24d247e8a7cc8013861949ad +Subproject commit 9672e9a7f90bcb782ccdfb2ee123ede7f2ef5d17 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index b8865596..b32d8bcc 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,12 +29,13 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 2d6384ff..a18067e9 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,11 +1,12 @@ from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher from .paged_decode_int8 import paged_decode_int8 -from .paged_prefill_int8 import dequant_paged_int8_to_bf16 +from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "paged_decode_int8", "dequant_paged_int8_to_bf16", + "dequant_paged_int8_to_bf16_inplace", ] diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py index 480c787e..4f33cd45 100644 --- a/vortex_torch/cache/triton_kernels/paged_decode_int8.py +++ b/vortex_torch/cache/triton_kernels/paged_decode_int8.py @@ -25,8 +25,8 @@ def _fwd_kernel_int8_stage1( Q, # [batch, num_qo_heads, head_dim] bf16 K_Buffer, # int8 paged: flat V_Buffer, # int8 paged: flat - K_Scale_Buffer, # float32: flat (one scale per token slot) - V_Scale_Buffer, # float32: flat + K_Scale_Buffer, # fp16: flat (one scale per token slot) + V_Scale_Buffer, # fp16: flat sm_scale, kv_indptr, # [batch + 1] int32, page-level kv_indices, # page indices @@ -118,7 +118,7 @@ def _fwd_kernel_int8_stage1( K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, - ) + ).to(tl.float32) k = k_int8 * k_scale[:, None] # Compute QK @@ -142,7 +142,7 @@ def _fwd_kernel_int8_stage1( V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, - ) + ).to(tl.float32) v = v_int8 * v_scale[:, None] # Online softmax accumulation @@ -251,8 +251,8 @@ def paged_decode_int8( q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 k_buffer: torch.Tensor, # int8 paged K cache v_buffer: torch.Tensor, # int8 paged V cache - k_scale_buffer: torch.Tensor, # float32 scale for K - v_scale_buffer: torch.Tensor, # float32 scale for V + k_scale_buffer: torch.Tensor, # fp16 scale for K + v_scale_buffer: torch.Tensor, # fp16 scale for V o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output kv_indptr: torch.Tensor, # [batch + 1] int32, page-level kv_indices: torch.Tensor, # page indices @@ -262,6 +262,8 @@ def paged_decode_int8( sm_scale: float, page_size: int, logit_cap: float = 0.0, + att_out: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits, Lv] + att_lse: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits] ): """ Paged decode attention with int8 KV cache and inline dequantization. @@ -283,17 +285,23 @@ def paged_decode_int8( num_warps = 4 if kv_group_num == 1 else 2 - # Intermediate buffers for split reduction - att_out = torch.empty( - (batch, head_num, MAX_KV_SPLITS, Lv), - dtype=torch.float32, - device=q.device, - ) - att_lse = torch.empty( - (batch, head_num, MAX_KV_SPLITS), - dtype=torch.float32, - device=q.device, - ) + # Use pre-allocated buffers if provided, otherwise allocate + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, + device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, + device=q.device, + ) + else: + att_lse = att_lse[:batch] stride_buf_kbs = k_buffer.shape[-1] stride_buf_vbs = v_buffer.shape[-1] diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py index 75c38574..89279833 100644 --- a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py +++ b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py @@ -16,7 +16,7 @@ @triton.jit def _dequant_pages_kernel( src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat - src_scale, # float32 scale buffer [num_pages, page_size, 1] flat + src_scale, # fp16 scale buffer [num_pages, page_size, 1] flat dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat page_indices, # int32 [num_accessed_pages] — which global pages to dequant NUM_PAGES: tl.constexpr, @@ -41,7 +41,7 @@ def _dequant_pages_kernel( # Scale: global_page_id * PAGE_SIZE + token_idx scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset) + scale = tl.load(src_scale + scale_offset).to(tl.float32) val_bf16 = (val_int8 * scale).to(tl.bfloat16) @@ -52,26 +52,35 @@ def _dequant_pages_kernel( def dequant_paged_int8_to_bf16( src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] - src_scale: torch.Tensor, # float32 [num_pages, page_size, 1] + src_scale: torch.Tensor, # fp16 [num_pages, page_size, 1] page_indices: torch.Tensor, # int32 [num_accessed_pages] page_size: int, head_dim: int, + out: torch.Tensor = None, # optional pre-allocated bf16 [>=num_accessed_pages, page_size, head_dim] ) -> torch.Tensor: """ Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. + If `out` is provided, writes into it (must have room for num_accessed_pages). + Otherwise allocates a new buffer. + Returns: bf16 tensor of shape [num_accessed_pages, page_size, head_dim] """ num_accessed_pages = page_indices.shape[0] if num_accessed_pages == 0: + if out is not None: + return out[:0] return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) - dst_bf16 = torch.empty( - (num_accessed_pages, page_size, head_dim), - dtype=torch.bfloat16, - device=src_int8.device, - ) + if out is not None: + dst_bf16 = out[:num_accessed_pages] + else: + dst_bf16 = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src_int8.device, + ) BLOCK_DIM = triton.next_power_of_2(head_dim) @@ -88,3 +97,72 @@ def dequant_paged_int8_to_bf16( ) return dst_bf16 + + +@triton.jit +def _dequant_pages_inplace_kernel( + src_int8, # int8 paged buffer flat + src_scale, # scale buffer flat (one scale per token slot) + dst_bf16, # bf16 destination buffer (same page layout as src) + page_indices, # int32 [num_pages] — which global pages to dequant + NUM_PAGES: tl.constexpr, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, +): + """Dequantize selected int8 pages to bf16, writing to the SAME page positions in dst.""" + page_idx = tl.program_id(0) # index into page_indices + token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + # Source and destination use the SAME offset (in-place layout) + offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + val_int8 = tl.load(src_int8 + offset, mask=mask_dim, other=0).to(tl.float32) + + scale_offset = global_page_id * PAGE_SIZE + token_idx + scale = tl.load(src_scale + scale_offset).to(tl.float32) + + val_bf16 = (val_int8 * scale).to(tl.bfloat16) + + # Write to the SAME page position in dst (not compacted) + tl.store(dst_bf16 + offset, val_bf16, mask=mask_dim) + + +def dequant_paged_int8_to_bf16_inplace( + src_int8: torch.Tensor, # int8 paged cache (flat) + src_scale: torch.Tensor, # fp16 scale buffer (flat) + dst_bf16: torch.Tensor, # bf16 destination (same shape as src_int8) + page_indices: torch.Tensor, # int32 [num_pages] — which pages to dequant + page_size: int, + head_dim: int, +) -> None: + """ + Dequantize selected pages from int8 cache to bf16 IN-PLACE. + + Unlike dequant_paged_int8_to_bf16 (which compacts into a dense buffer), + this writes to the SAME page positions in dst_bf16, preserving the paged layout. + Used to populate the bf16 working buffer for forward_cache (centroid computation). + """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_inplace_kernel[grid]( + src_int8, + src_scale, + dst_bf16, + page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 43184280..2a2c785d 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -40,8 +40,8 @@ def set_kv_buffer_kernel( def set_kv_buffer_int8_kernel( k_cache, # int8 paged K cache v_cache, # int8 paged V cache - k_scale_cache, # float32 per-token K scale [num_pages, page_size, 1] - v_scale_cache, # float32 per-token V scale [num_pages, page_size, 1] + k_scale_cache, # fp16 per-token K scale [num_pages, page_size, 1] + v_scale_cache, # fp16 per-token V scale [num_pages, page_size, 1] new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] loc, # int64 token positions @@ -87,11 +87,11 @@ def set_kv_buffer_int8_kernel( tl.store(dst_k_ptr, q_k) tl.store(dst_v_ptr, q_v) - # Write per-token scales: shape [num_pages, page_size, 1] + # Write per-token scales (fp16): shape [num_pages, page_size, 1] # Layout: page_id * PAGE_SIZE + in_page_offset (flat per-head, one scale per token per head) scale_offset = (page_id * NUM_KV_HEAD + head_id) * PAGE_SIZE + in_page_offset - tl.store(k_scale_cache + scale_offset, scale_k) - tl.store(v_scale_cache + scale_offset, scale_v) + tl.store(k_scale_cache + scale_offset, scale_k.to(tl.float16)) + tl.store(v_scale_cache + scale_offset, scale_v.to(tl.float16)) def set_kv_buffer_int8_launcher( From 584f23355412a4464215ccb15d06758ad1b2762c Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 23 Feb 2026 07:07:28 +0000 Subject: [PATCH 03/24] 1. Correction on int8 (maximize memory occupation) 2. Implement fp8 quantization. --- CLAUDE.md | 141 -------- examples/verify_algo.py | 14 +- examples/verify_algo_fp8.sh | 25 ++ ...rify_algo_quant.sh => verify_algo_int8.sh} | 0 third_party/sglang | 2 +- vortex_torch/cache/__init__.py | 3 +- vortex_torch/cache/context.py | 20 +- vortex_torch/cache/reduce.py | 6 +- vortex_torch/cache/triton_kernels/__init__.py | 3 +- .../cache/triton_kernels/reduce_impl.py | 328 +++++++++--------- vortex_torch/cache/triton_kernels/set_kv.py | 97 +++++- 11 files changed, 319 insertions(+), 320 deletions(-) delete mode 100644 CLAUDE.md create mode 100755 examples/verify_algo_fp8.sh rename examples/{verify_algo_quant.sh => verify_algo_int8.sh} (100%) diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 1593d611..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,141 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. - -## Build & Install - -```bash -# Install SGLang dependency (custom fork in third_party/) -cd third_party/sglang && bash install.sh && cd ../../ - -# Install Vortex (editable mode, compiles CUDA extensions for SM_89/SM_90) -pip install -e . -``` - -Requires Python >=3.10, torch>=2.7. CUDA extensions are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu). - -## Running Examples - -```bash -# Single algorithm verification against SGLang -python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention - -# Batch test multiple algorithms -bash examples/verify_algo.sh -``` - -## Building Documentation - -```bash -make -C docs html -``` - -Uses Sphinx with myst_parser and furo theme. Deployed via GitHub Actions on push to v1 branch. - -## Architecture - -### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) - -All sparse attention algorithms inherit from `vFlow` and implement three methods: - -- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. -- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. -- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. - -Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. - -### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) - -Operators (`vOp` subclasses) run in two modes: -- **Profile mode**: Pre-compute output shapes and allocate buffers -- **Execute mode**: Perform actual GPU computation - -Operators are split into two parallel hierarchies: -- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load -- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup - -Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. - -### Tensor Format (`vortex_torch/abs/tensor.py`) - -`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. - -### Context System (`vortex_torch/abs/context_base.py`) - -`ContextBase` carries per-step runtime state. Specialized as: -- `Indexer.Context`: Page layout, head config, hardware info -- `Cache.Context`: Page size, total pages, model info - -### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) - -- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) -- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation -- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds - -### SGLang Integration - -Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, and transpose operations. - -## Key Conventions - -- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` -- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) -- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` -- **Branch**: Main development is on `v1` - -## Workflow Orchestration - -### 1. Plan Node Default -- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) -- If something goes sideways, STOP and re-plan immediately - don't keep pushing -- Use plan mode for verification steps, not just building -- Write detailed specs upfront to reduce ambiguity - -### 2. Subagent Strategy -- Use subagents liberally to keep main context window clean -- Offload research, exploration, and parallel analysis to subagents -- For complex problems, throw more compute at it via subagents -- One tack per subagent for focused execution - -### 3. Self-Improvement Loop -- After ANY correction from the user: update `tasks/lessons.md` with the pattern -- Write rules for yourself that prevent the same mistake -- Ruthlessly iterate on these lessons until mistake rate drops -- Review lessons at session start for relevant project - -### 4. Verification Before Done -- Never mark a task complete without proving it works -- Diff behavior between main and your changes when relevant -- Ask yourself: "Would a staff engineer approve this?" -- Run tests, check logs, demonstrate correctness - -### 5. Demand Elegance (Balanced) -- For non-trivial changes: pause and ask "is there a more elegant way?" -- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" -- Skip this for simple, obvious fixes - don't over-engineer -- Challenge your own work before presenting it - -### 6. Autonomous Bug Fixing -- When given a bug report: just fix it. Don't ask for hand-holding -- Point at logs, errors, failing tests - then resolve them -- Zero context switching required from the user -- Go fix failing CI tests without being told how - -## Task Management - -1. **Plan First**: Write plan to `tasks/todo.md` with checkable items -2. **Verify Plan**: Check in before starting implementation -3. **Track Progress**: Mark items complete as you go -4. **Explain Changes**: High-level summary at each step -5. **Document Results**: Add review section to `tasks/todo.md` -6. **Capture Lessons**: Update `tasks/lessons.md` after corrections - -## Core Principles - -- **Simplicity First**: Make every change as simple as possible. Impact minimal code. -- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. -- **Minimat Impact**: Changes should only touch what's necessary. Avoid introducing bugs. \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index f4185983..9958b7e3 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -113,13 +113,13 @@ def verify_algos( } ) # --- Per-question debug output --- - print(f"[Q{len(results):03d}] score={float(result):.1f} " - f"tokens={item['meta_info']['completion_tokens']} " - f"latency={item['meta_info']['e2e_latency']:.2f}s " - f"gold={golds[0]}") - print(f" question: {data['question'][:120]}...") - print(f" prediction: {predictions[:200]}...") - print() + # print(f"[Q{len(results):03d}] score={float(result):.1f} " + # f"tokens={item['meta_info']['completion_tokens']} " + # f"latency={item['meta_info']['e2e_latency']:.2f}s " + # f"gold={golds[0]}") + # print(f" question: {data['question'][:120]}...") + # print(f" prediction: {predictions[:200]}...") + # print() total_accuracy = 0.0 diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh new file mode 100755 index 00000000..7f266e5e --- /dev/null +++ b/examples/verify_algo_fp8.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=3 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype fp8_e4m3 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_int8.sh similarity index 100% rename from examples/verify_algo_quant.sh rename to examples/verify_algo_int8.sh diff --git a/third_party/sglang b/third_party/sglang index 9672e9a7..7105719f 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 9672e9a7f90bcb782ccdfb2ee123ede7f2ef5d17 +Subproject commit 7105719f0a2ac464ee7ffdc0a899fa6a656656a2 diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index b32d8bcc..8c4d0e0f 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,12 +29,13 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, dequant_paged_int8_to_bf16_inplace +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", "dequant_paged_int8_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index ae2dd5c2..dd1bd024 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -10,17 +10,21 @@ class Context(ContextBase): """ __slots__ = ContextBase.__slots__ + ( - + #page infomation "max_new_tokens_per_batch", "page_size", "total_num_pages", - + #model infomation "head_dim", "head_num", - + # auxilary memory in graph "_aux_total_bytes", - - "_aux_total_flops" + + "_aux_total_flops", + + # FP8 quantization: fp8_type (0=none, 1=e4m3, 2=e5m2), kv_scale (per-tensor) + "fp8_type", + "kv_scale", ) @@ -36,7 +40,11 @@ def __init__(self) -> None: elif name == "_aux_total_flops": object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": - object.__setattr__(self, name, Mode.profile) + object.__setattr__(self, name, Mode.profile) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = no fp8 (bf16 default) + elif name == "kv_scale": + object.__setattr__(self, name, 1.0) # identity scale for bf16 else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 3c4edf2f..5800458c 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,8 +345,10 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, fp8_type, scale) + fp8_type = getattr(ctx, 'fp8_type', 0) + scale = getattr(ctx, 'kv_scale', 1.0) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, fp8_type, scale) return output diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index a18067e9..009e728e 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,10 +1,11 @@ -from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher +from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher from .paged_decode_int8 import paged_decode_int8 from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", + "set_kv_buffer_fp8_launcher", "paged_decode_int8", "dequant_paged_int8_to_bf16", "dequant_paged_int8_to_bf16_inplace", diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9921e082..9670acd0 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -4,6 +4,16 @@ from ..context import Context from ...utils import ReduceType + +# --------------------------------------------------------------------------- +# Helper: Load a page block from src_ptr, handling bf16 or fp8-stored-as-uint8. +# FP8_TYPE == 0 -> bf16 pointer, load normally +# FP8_TYPE == 1 -> uint8 pointer, bitcast to float8e4nv, dequant with scale +# FP8_TYPE == 2 -> uint8 pointer, bitcast to float8e5, dequant with scale +# All paths return a float32 tensor ready for reduction. +# --------------------------------------------------------------------------- + + @triton.jit def reduce_pp_kernel( x, output, loc, @@ -12,9 +22,11 @@ def reduce_pp_kernel( NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm -DIM: tl.constexpr # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -): - +DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 +FP8_TYPE: tl.constexpr, # 0: bf16, 1: e4m3, 2: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +): + token_id = tl.program_id(0) head_id = tl.program_id(1) @@ -29,7 +41,15 @@ def reduce_pp_kernel( rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - page_block = tl.load(src_ptr) + + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -40,7 +60,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) @@ -55,7 +75,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) @@ -71,11 +91,13 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -85,7 +107,9 @@ def reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -97,11 +121,13 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -111,7 +137,9 @@ def _reduce_pp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -119,84 +147,67 @@ def _reduce_pp( @triton.jit def reduce_rp_kernel( x, output, loc, - x_D0: tl.constexpr, # rows per token-page - x_D1: tl.constexpr, # cols per token-page + x_D0: tl.constexpr, + x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, - REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) - DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + REDUCE_TYPE: tl.constexpr, + DIM: tl.constexpr, + FP8_TYPE: tl.constexpr, + scale, ): - - # Program IDs: - # pid0 = token index (0 .. num_tokens-1) - # pid1 = head index (0 .. NUM_KV_HEAD-1) + token_id = tl.program_id(0) head_id = tl.program_id(1) - # Load the absolute position of this token (used to map to page index). token_position = tl.load(loc + token_id) - # Only the last token of a page triggers the reduction. if (token_position + 1) % PAGE_SIZE != 0: return - # Output page index: - # Logical page = token_position // PAGE_SIZE - # One vector per head, so linearize by NUM_KV_HEAD. page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id - - # Input layout is [num_tokens, num_heads, x_D0, x_D1] (row-major). - # For this token/head, compute the base element offset in `x`. x_offset = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - # Build 2D indices within a page (row-major addressing). - rows = tl.arange(0, x_D0)[:, None] # shape [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # shape [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_offset + rows * x_D1 + cols - # Load the full page block for this (token_id, head_id). - # Assumes the page is full; add masks here if you have partial tiles. - page_block = tl.load(src_ptr) + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) - # Reduction: if DIM == 1: - # Reduce over rows (axis=0) -> output vector length x_D1 (per-column reduce). - if REDUCE_TYPE == 0: # Mean - # NOTE: precision-sensitive workloads may want fp32 accumulation: - # s = tl.sum(page_block.to(tl.float32), axis=0) - # reduce_vec = (s / x_D0).to(tl.bfloat16) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: # L2Norm (sqrt(sum(x*x))); NOT RMS - # For RMS, use: tl.sqrt(tl.sum(page_block*page_block, axis=0) / x_D0) - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # Write to output: layout [num_pages, x_D1] for DIM==1. dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - # DIM == 2: Reduce over cols (axis=1) -> output vector length x_D0 (per-row reduce). - if REDUCE_TYPE == 0: # Mean - # s = tl.sum(page_block.to(tl.float32), axis=1) - # reduce_vec = (s / x_D1).to(tl.bfloat16) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: # L2Norm (sqrt(sum(x*x))); NOT RMS - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # Write to output: layout [num_pages, x_D0] for DIM==2. dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) - def reduce_rp( @@ -206,11 +217,13 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -220,7 +233,9 @@ def reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @@ -232,11 +247,13 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rp_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -246,92 +263,76 @@ def _reduce_rp( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @triton.jit def reduce_pr_kernel( x, output, loc, -x_D0: tl.constexpr, # rows per page -x_D1: tl.constexpr, # cols per page +x_D0: tl.constexpr, +x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +REDUCE_TYPE: tl.constexpr, +DIM: tl.constexpr, +FP8_TYPE: tl.constexpr, +scale, ): - """ - Layouts: - x: [num_pages * NUM_KV_HEAD, x_D0, x_D1] (page-major, row-major inside page) - output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) - - Behavior: - - token_id comes from pid0; head_id comes from pid1. - - Read loc[token_id] to get absolute position; only proceed at page end. - - Map token -> page via page_idx = (token_position // PAGE_SIZE). - - Read the whole page for this (page_idx, head_id), do reduction, - then write a single vector to output at (token_id, head_id, :). - """ - - # --- Program IDs --- - token_id = tl.program_id(0) # [0 .. num_tokens-1] - head_id = tl.program_id(1) # [0 .. NUM_KV_HEAD-1] - - # --- Trigger only at end-of-page token --- + + token_id = tl.program_id(0) + head_id = tl.program_id(1) + token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return - # --- Page indexing for x (page-major) --- - # page linear id across heads page_idx = token_position // PAGE_SIZE page_id = page_idx * NUM_KV_HEAD + head_id - # Base element offset into x for this (page_id, head_id) - # x is laid out as contiguous pages, each page is [x_D0, x_D1] x_offset = page_id * x_D0 * x_D1 - # 2D row-major addressing within the page - rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_offset + rows * x_D1 + cols - # Load the full page block. Assumes full tiles; add masks if needed. - page_block = tl.load(src_ptr) + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_block = tl.load(src_ptr).to(tl.float32) - # --- Reduction & write-out --- if DIM == 1: - # Reduce over rows (axis=0) -> per-column vector, length = x_D1 - if REDUCE_TYPE == 0: # Mean - # For better accuracy you may upcast: tl.sum(page_block.to(tl.float32), axis=0) + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: # L2Norm (NOT RMS) - s = tl.sum(page_block * page_block, axis=0).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=0) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - # output is token-major: [num_tokens, NUM_KV_HEAD, x_D1] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 dst_ptr = output + out_base + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - # DIM == 2: Reduce over cols (axis=1) -> per-row vector, length = x_D0 - if REDUCE_TYPE == 0: # Mean + if REDUCE_TYPE == 0: reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: # L2Norm (NOT RMS) - s = tl.sum(page_block * page_block, axis=1).to(tl.float32) + else: + s = tl.sum(page_block * page_block, axis=1) reduce_vec = tl.sqrt(s).to(tl.bfloat16) - - # output is token-major: [num_tokens, NUM_KV_HEAD, x_D0] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 dst_ptr = output + out_base + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) @@ -344,11 +345,13 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -358,9 +361,11 @@ def reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) - + def _reduce_pr( x: torch.Tensor, output: torch.Tensor, @@ -369,11 +374,13 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_pr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -383,72 +390,68 @@ def _reduce_pr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) @triton.jit def reduce_rr_kernel( x, output, loc, -x_D0: tl.constexpr, # rows per token-page -x_D1: tl.constexpr, # cols per token-page +x_D0: tl.constexpr, +x_D1: tl.constexpr, NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) -DIM: tl.constexpr # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +REDUCE_TYPE: tl.constexpr, +DIM: tl.constexpr, +FP8_TYPE: tl.constexpr, +scale, ): - """ - Layouts: - x: [num_tokens * NUM_KV_HEAD, x_D0, x_D1] (token-major) - output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) - Only the last token of each page performs the reduction and writes to output[token_id, head_id, :]. - """ - - - # program ids - token_id = tl.program_id(0) # 0..num_tokens-1 - head_id = tl.program_id(1) # 0..NUM_KV_HEAD-1 + token_id = tl.program_id(0) + head_id = tl.program_id(1) - # trigger only at end-of-page token token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return - # ---- read from x (token-major) ---- x_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] - cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] + rows = tl.arange(0, x_D0)[:, None] + cols = tl.arange(0, x_D1)[None, :] src_ptr = x + x_base + rows * x_D1 + cols - page_blk = tl.load(src_ptr) # assumes full page; add masks if needed - # ---- reduce ---- + if FP8_TYPE == 1: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale + elif FP8_TYPE == 2: + raw = tl.load(src_ptr) + page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale + else: + page_blk = tl.load(src_ptr).to(tl.float32) + if DIM == 1: - # over rows -> axis=0 -> vector len x_D1 - if REDUCE_TYPE == 0: # Mean - # For better accuracy you may upcast to fp32 before sum. + if REDUCE_TYPE == 0: vec = (tl.sum(page_blk, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: vec = tl.max(page_blk, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: vec = tl.min(page_blk, axis=0).to(tl.bfloat16) - else: # L2Norm (NOT RMS) + else: s = tl.sum(page_blk * page_blk, axis=0) vec = tl.sqrt(s).to(tl.bfloat16) - # ---- write to output (token-major) ---- out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 tl.store(output + out_base + tl.arange(0, x_D1), vec) else: - # DIM == 2: over cols -> axis=1 -> vector len x_D0 - if REDUCE_TYPE == 0: # Mean + if REDUCE_TYPE == 0: vec = (tl.sum(page_blk, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: # Max + elif REDUCE_TYPE == 1: vec = tl.max(page_blk, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: # Min + elif REDUCE_TYPE == 2: vec = tl.min(page_blk, axis=1).to(tl.bfloat16) - else: # L2Norm (NOT RMS) + else: s = tl.sum(page_blk * page_blk, axis=1) vec = tl.sqrt(s).to(tl.bfloat16) @@ -456,7 +459,6 @@ def reduce_rr_kernel( tl.store(output + out_base + tl.arange(0, x_D0), vec) - def reduce_rr( x: torch.Tensor, output: torch.Tensor, @@ -464,11 +466,13 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = ctx.head_num - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -478,9 +482,11 @@ def reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, ) - + def _reduce_rr( x: torch.Tensor, @@ -490,11 +496,13 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, +fp8_type: int = 0, +scale: float = 1.0, ): - + NNZ = loc.shape[0] NUM_KV_HEAD = num_kv_heads - + reduce_rr_kernel[(NNZ, NUM_KV_HEAD)]( x=x, output=output, @@ -504,5 +512,7 @@ def _reduce_rr( NUM_KV_HEAD=NUM_KV_HEAD, PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, - DIM=dim - ) \ No newline at end of file + DIM=dim, + FP8_TYPE=fp8_type, + scale=scale, + ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 2a2c785d..6b289df3 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -131,11 +131,11 @@ def set_kv_buffer_launcher( loc: torch.LongTensor, page_size: int ): - + NNZ = loc.shape[0] NUM_KV_HEAD = new_k.shape[1] HEAD_DIM = new_k.shape[2] - + set_kv_buffer_kernel[(NNZ, NUM_KV_HEAD)]( k_cache, v_cache, @@ -148,3 +148,96 @@ def set_kv_buffer_launcher( page_size ) + +@triton.jit +def set_kv_buffer_fp8_kernel( + k_cache, # uint8 paged K cache + v_cache, # uint8 paged V cache + new_k, # bf16 input K [NNZ, NUM_KV_HEAD, HEAD_DIM] + new_v, # bf16 input V [NNZ, NUM_KV_HEAD, HEAD_DIM] + loc, # int64 token positions + NUM_KV_HEAD: tl.constexpr, + NNZ: tl.constexpr, + HEAD_DIM: tl.constexpr, + PAGE_SIZE: tl.constexpr, + FP8_TYPE: tl.constexpr, # 1: e4m3 (max=448), 2: e5m2 (max=57344) + k_scale, # float: per-tensor scale for K quantization + v_scale, # float: per-tensor scale for V quantization +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache.""" + token_id = tl.program_id(0) + if token_id >= NNZ: + return + head_id = tl.program_id(1) + dim = tl.arange(0, HEAD_DIM) + + # Load bf16 source values + src_ptr = token_id * NUM_KV_HEAD * HEAD_DIM + head_id * HEAD_DIM + dim + src_k = tl.load(new_k + src_ptr).to(tl.float32) + src_v = tl.load(new_v + src_ptr).to(tl.float32) + + # Scale down: quantized = real_value / scale + inv_k_scale = 1.0 / k_scale + inv_v_scale = 1.0 / v_scale + scaled_k = src_k * inv_k_scale + scaled_v = src_v * inv_v_scale + + # Clamp and cast to fp8, then bitcast to uint8 for storage + if FP8_TYPE == 1: + # e4m3: max = 448.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -448.0), 448.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -448.0), 448.0) + q_k = clamped_k.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e4nv).to(tl.uint8, bitcast=True) + else: + # e5m2: max = 57344.0 + clamped_k = tl.minimum(tl.maximum(scaled_k, -57344.0), 57344.0) + clamped_v = tl.minimum(tl.maximum(scaled_v, -57344.0), 57344.0) + q_k = clamped_k.to(tl.float8e5).to(tl.uint8, bitcast=True) + q_v = clamped_v.to(tl.float8e5).to(tl.uint8, bitcast=True) + + # Compute paged destination offset + token_position = tl.load(loc + token_id) + page_id = token_position // PAGE_SIZE + in_page_offset = token_position % PAGE_SIZE + position_trans = page_id * (PAGE_SIZE * NUM_KV_HEAD) + head_id * PAGE_SIZE + in_page_offset + + # Write uint8 values + dst_k_ptr = k_cache + position_trans * HEAD_DIM + dim + dst_v_ptr = v_cache + position_trans * HEAD_DIM + dim + tl.store(dst_k_ptr, q_k) + tl.store(dst_v_ptr, q_v) + + +def set_kv_buffer_fp8_launcher( + k_cache: torch.Tensor, + v_cache: torch.Tensor, + new_k: torch.Tensor, + new_v: torch.Tensor, + loc: torch.LongTensor, + page_size: int, + k_scale: float, + v_scale: float, + fp8_type: int = 1, +): + """Quantize bf16 K/V to fp8, bitcast to uint8, and scatter into paged cache. + + Args: + fp8_type: 1 for e4m3 (default), 2 for e5m2. + k_scale: per-tensor scale used for K quantization. + v_scale: per-tensor scale used for V quantization. + """ + NNZ = loc.shape[0] + NUM_KV_HEAD = new_k.shape[1] + HEAD_DIM = new_k.shape[2] + + set_kv_buffer_fp8_kernel[(NNZ, NUM_KV_HEAD)]( + k_cache, v_cache, + new_k, new_v, + loc, + NUM_KV_HEAD, NNZ, HEAD_DIM, page_size, + FP8_TYPE=fp8_type, + k_scale=k_scale, + v_scale=v_scale, + ) + From f25fb13e6075111a9b05aedb060a3e7bd346cebd Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 1 Mar 2026 08:02:49 +0000 Subject: [PATCH 04/24] update on parameters for reduce_pp_kernel with quantization --- setup.py | 2 +- vortex_torch/cache/context.py | 12 +- vortex_torch/cache/reduce.py | 7 +- .../cache/triton_kernels/reduce_impl.py | 305 ++++++++++++------ 4 files changed, 224 insertions(+), 102 deletions(-) diff --git a/setup.py b/setup.py index 6efeebe5..f35ddae6 100644 --- a/setup.py +++ b/setup.py @@ -16,7 +16,7 @@ sources=[ 'csrc/register.cc', 'csrc/utils_sglang.cu', - 'csrc/topk.cu' + 'csrc/topk.cu', ], include_dirs=['csrc'], extra_compile_args={ diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index dd1bd024..0e7171cc 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -22,9 +22,11 @@ class Context(ContextBase): "_aux_total_flops", - # FP8 quantization: fp8_type (0=none, 1=e4m3, 2=e5m2), kv_scale (per-tensor) - "fp8_type", + # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), + # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + "quant_type", "kv_scale", + "kv_scale_ptr", ) @@ -41,10 +43,12 @@ def __init__(self) -> None: object.__setattr__(self, name, 0) # start from 0 flops elif name == "mode": object.__setattr__(self, name, Mode.profile) - elif name == "fp8_type": - object.__setattr__(self, name, 0) # 0 = no fp8 (bf16 default) + elif name == "quant_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) elif name == "kv_scale": object.__setattr__(self, name, 1.0) # identity scale for bf16 + elif name == "kv_scale_ptr": + object.__setattr__(self, name, None) # per-token scale tensor (int8 only) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/cache/reduce.py b/vortex_torch/cache/reduce.py index 5800458c..eb94795e 100644 --- a/vortex_torch/cache/reduce.py +++ b/vortex_torch/cache/reduce.py @@ -345,10 +345,11 @@ def execute( ) output = self.output_buffer - # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, fp8_type, scale) - fp8_type = getattr(ctx, 'fp8_type', 0) + # Launch the kernel/implementation: impl(x, output, loc, ctx, dim, reduce_type, quant_type, scale, kv_scale_ptr) + quant_type = getattr(ctx, 'quant_type', 0) scale = getattr(ctx, 'kv_scale', 1.0) - self.impl(x, output, loc, ctx, self.dim, self.reduce_type, fp8_type, scale) + kv_scale_ptr = getattr(ctx, 'kv_scale_ptr', None) + self.impl(x, output, loc, ctx, self.dim, self.reduce_type, quant_type, scale, kv_scale_ptr) return output diff --git a/vortex_torch/cache/triton_kernels/reduce_impl.py b/vortex_torch/cache/triton_kernels/reduce_impl.py index 9670acd0..0146af7b 100644 --- a/vortex_torch/cache/triton_kernels/reduce_impl.py +++ b/vortex_torch/cache/triton_kernels/reduce_impl.py @@ -6,11 +6,12 @@ # --------------------------------------------------------------------------- -# Helper: Load a page block from src_ptr, handling bf16 or fp8-stored-as-uint8. -# FP8_TYPE == 0 -> bf16 pointer, load normally -# FP8_TYPE == 1 -> uint8 pointer, bitcast to float8e4nv, dequant with scale -# FP8_TYPE == 2 -> uint8 pointer, bitcast to float8e5, dequant with scale -# All paths return a float32 tensor ready for reduction. +# Helper: Load a page block from src_ptr, handling bf16 / int8 / fp8-stored-as-uint8. +# QUANT_TYPE == 0 -> bf16 pointer, load normally +# QUANT_TYPE == 1 -> int8 pointer, dequant with per-row scale from kv_scale_ptr +# QUANT_TYPE == 2 -> uint8 pointer, bitcast to float8e4nv, dequant with per-tensor scale +# QUANT_TYPE == 3 -> uint8 pointer, bitcast to float8e5, dequant with per-tensor scale +# All quantised paths return a float32 tensor ready for reduction. # --------------------------------------------------------------------------- @@ -23,8 +24,9 @@ def reduce_pp_kernel( PAGE_SIZE: tl.constexpr, REDUCE_TYPE: tl.constexpr, # 0:Mean, 1:Max, 2:Min, 3:L2Norm DIM: tl.constexpr, # 1: over rows (axis=0) -> len x_D1; 2: over cols (axis=1) -> len x_D0 -FP8_TYPE: tl.constexpr, # 0: bf16, 1: e4m3, 2: e5m2 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): token_id = tl.program_id(0) @@ -42,14 +44,21 @@ def reduce_pp_kernel( cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + # Per-row scales stored at kv_scale_ptr[page_id * x_D0 + row] + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) # [x_D0] + page_block = raw * row_scales[:, None] # broadcast [x_D0, 1] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) if DIM == 1: # reduce over rows -> axis=0 -> length x_D1 @@ -60,7 +69,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=0) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) @@ -75,7 +84,7 @@ def reduce_pp_kernel( elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) else: # L2Norm - s = tl.sum(page_block * page_block, axis=1) + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) @@ -91,8 +100,9 @@ def reduce_pp( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -108,8 +118,9 @@ def reduce_pp( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -121,8 +132,9 @@ def _reduce_pp( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -138,8 +150,9 @@ def _reduce_pp( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -147,69 +160,102 @@ def _reduce_pp( @triton.jit def reduce_rp_kernel( x, output, loc, - x_D0: tl.constexpr, - x_D1: tl.constexpr, + x_D0: tl.constexpr, # rows per token-page + x_D1: tl.constexpr, # cols per token-page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, - REDUCE_TYPE: tl.constexpr, - DIM: tl.constexpr, - FP8_TYPE: tl.constexpr, - scale, + REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) + DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + scale, # float: 1.0 for bf16, kv_scale for fp8 + kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): + # Program IDs: + # pid0 = token index (0 .. num_tokens-1) + # pid1 = head index (0 .. NUM_KV_HEAD-1) token_id = tl.program_id(0) head_id = tl.program_id(1) + # Load the absolute position of this token (used to map to page index). token_position = tl.load(loc + token_id) + # Only the last token of a page triggers the reduction. if (token_position + 1) % PAGE_SIZE != 0: return + # Output page index: + # Logical page = token_position // PAGE_SIZE + # One vector per head, so linearize by NUM_KV_HEAD. page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + + # Input layout is [num_tokens, num_heads, x_D0, x_D1] (row-major). + # For this token/head, compute the base element offset in `x`. x_offset = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + # Build 2D indices within a page (row-major addressing). + rows = tl.arange(0, x_D0)[:, None] # shape [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # shape [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + # Load the full page block for this (token_id, head_id). + # Assumes the page is full; add masks here if you have partial tiles. + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) + # Reduction: if DIM == 1: - if REDUCE_TYPE == 0: + # Reduce over rows (axis=0) -> output vector length x_D1 (per-column reduce). + if REDUCE_TYPE == 0: # Mean + # NOTE: precision-sensitive workloads may want fp32 accumulation: + # s = tl.sum(page_block.to(tl.float32), axis=0) + # reduce_vec = (s / x_D0).to(tl.bfloat16) reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=0) + else: # L2Norm (sqrt(sum(x*x))); NOT RMS + # For RMS, use: tl.sqrt(tl.sum(page_block*page_block, axis=0) / x_D0) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # Write to output: layout [num_pages, x_D1] for DIM==1. dst_ptr = output + page_id * x_D1 + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: Reduce over cols (axis=1) -> output vector length x_D0 (per-row reduce). + if REDUCE_TYPE == 0: # Mean + # s = tl.sum(page_block.to(tl.float32), axis=1) + # reduce_vec = (s / x_D1).to(tl.bfloat16) reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=1) + else: # L2Norm (sqrt(sum(x*x))); NOT RMS + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # Write to output: layout [num_pages, x_D0] for DIM==2. dst_ptr = output + page_id * x_D0 + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) + def reduce_rp( x: torch.Tensor, output: torch.Tensor, @@ -217,8 +263,9 @@ def reduce_rp( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -234,8 +281,9 @@ def reduce_rp( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -247,8 +295,9 @@ def _reduce_rp( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -264,75 +313,110 @@ def _reduce_rp( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @triton.jit def reduce_pr_kernel( x, output, loc, -x_D0: tl.constexpr, -x_D1: tl.constexpr, +x_D0: tl.constexpr, # rows per page +x_D1: tl.constexpr, # cols per page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, -DIM: tl.constexpr, -FP8_TYPE: tl.constexpr, -scale, +REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): - - token_id = tl.program_id(0) - head_id = tl.program_id(1) - + """ + Layouts: + x: [num_pages * NUM_KV_HEAD, x_D0, x_D1] (page-major, row-major inside page) + output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) + + Behavior: + - token_id comes from pid0; head_id comes from pid1. + - Read loc[token_id] to get absolute position; only proceed at page end. + - Map token -> page via page_idx = (token_position // PAGE_SIZE). + - Read the whole page for this (page_idx, head_id), do reduction, + then write a single vector to output at (token_id, head_id, :). + """ + + # --- Program IDs --- + token_id = tl.program_id(0) # [0 .. num_tokens-1] + head_id = tl.program_id(1) # [0 .. NUM_KV_HEAD-1] + + # --- Trigger only at end-of-page token --- token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return + # --- Page indexing for x (page-major) --- + # page linear id across heads page_idx = token_position // PAGE_SIZE page_id = page_idx * NUM_KV_HEAD + head_id + # Base element offset into x for this (page_id, head_id) + # x is laid out as contiguous pages, each page is [x_D0, x_D1] x_offset = page_id * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + # 2D row-major addressing within the page + rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_offset + rows * x_D1 + cols - if FP8_TYPE == 1: + # Load the full page block. Assumes full tiles; add masks if needed. + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_block = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_block = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_block = tl.load(src_ptr).to(tl.float32) + page_block = tl.load(src_ptr) + # --- Reduction & write-out --- if DIM == 1: - if REDUCE_TYPE == 0: + # Reduce over rows (axis=0) -> per-column vector, length = x_D1 + if REDUCE_TYPE == 0: # Mean + # For better accuracy you may upcast: tl.sum(page_block.to(tl.float32), axis=0) reduce_vec = (tl.sum(page_block, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=0).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=0) + else: # L2Norm (NOT RMS) + s = tl.sum(page_block * page_block, axis=0).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + # output is token-major: [num_tokens, NUM_KV_HEAD, x_D1] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 dst_ptr = output + out_base + tl.arange(0, x_D1) tl.store(dst_ptr, reduce_vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: Reduce over cols (axis=1) -> per-row vector, length = x_D0 + if REDUCE_TYPE == 0: # Mean reduce_vec = (tl.sum(page_block, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max reduce_vec = tl.max(page_block, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min reduce_vec = tl.min(page_block, axis=1).to(tl.bfloat16) - else: - s = tl.sum(page_block * page_block, axis=1) + else: # L2Norm (NOT RMS) + s = tl.sum(page_block * page_block, axis=1).to(tl.float32) reduce_vec = tl.sqrt(s).to(tl.bfloat16) + + # output is token-major: [num_tokens, NUM_KV_HEAD, x_D0] out_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 dst_ptr = output + out_base + tl.arange(0, x_D0) tl.store(dst_ptr, reduce_vec) @@ -345,8 +429,9 @@ def reduce_pr( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -362,8 +447,9 @@ def reduce_pr( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) def _reduce_pr( @@ -374,8 +460,9 @@ def _reduce_pr( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -391,67 +478,92 @@ def _reduce_pr( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @triton.jit def reduce_rr_kernel( x, output, loc, -x_D0: tl.constexpr, -x_D1: tl.constexpr, +x_D0: tl.constexpr, # rows per token-page +x_D1: tl.constexpr, # cols per token-page NUM_KV_HEAD: tl.constexpr, PAGE_SIZE: tl.constexpr, -REDUCE_TYPE: tl.constexpr, -DIM: tl.constexpr, -FP8_TYPE: tl.constexpr, -scale, +REDUCE_TYPE: tl.constexpr, # 0: Mean, 1: Max, 2: Min, 3: L2Norm (not RMS) +DIM: tl.constexpr, # 1: reduce over rows -> len x_D1; 2: reduce over cols -> len x_D0 +QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 +scale, # float: 1.0 for bf16, kv_scale for fp8 +kv_scale_ptr, # pointer to per-token int8 scales (unused when QUANT_TYPE != 1) ): + """ + Layouts: + x: [num_tokens * NUM_KV_HEAD, x_D0, x_D1] (token-major) + output: [num_tokens * NUM_KV_HEAD, vec_len] (token-major; vec_len = x_D1 if DIM==1 else x_D0) + + Only the last token of each page performs the reduction and writes to output[token_id, head_id, :]. + """ - token_id = tl.program_id(0) - head_id = tl.program_id(1) + # program ids + token_id = tl.program_id(0) # 0..num_tokens-1 + head_id = tl.program_id(1) # 0..NUM_KV_HEAD-1 + + # trigger only at end-of-page token token_position = tl.load(loc + token_id) if (token_position + 1) % PAGE_SIZE != 0: return + # ---- read from x (token-major) ---- x_base = (token_id * NUM_KV_HEAD + head_id) * x_D0 * x_D1 - rows = tl.arange(0, x_D0)[:, None] - cols = tl.arange(0, x_D1)[None, :] + rows = tl.arange(0, x_D0)[:, None] # [x_D0, 1] + cols = tl.arange(0, x_D1)[None, :] # [1, x_D1] src_ptr = x + x_base + rows * x_D1 + cols - if FP8_TYPE == 1: + if QUANT_TYPE == 1: + # int8: load int8 values, dequant with per-row scale + raw = tl.load(src_ptr).to(tl.float32) + page_id = (token_position // PAGE_SIZE) * NUM_KV_HEAD + head_id + scale_offset = page_id * x_D0 + tl.arange(0, x_D0) + row_scales = tl.load(kv_scale_ptr + scale_offset).to(tl.float32) + page_blk = raw * row_scales[:, None] + elif QUANT_TYPE == 2: raw = tl.load(src_ptr) page_blk = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * scale - elif FP8_TYPE == 2: + elif QUANT_TYPE == 3: raw = tl.load(src_ptr) page_blk = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * scale else: - page_blk = tl.load(src_ptr).to(tl.float32) + page_blk = tl.load(src_ptr) # assumes full page; add masks if needed + # ---- reduce ---- if DIM == 1: - if REDUCE_TYPE == 0: + # over rows -> axis=0 -> vector len x_D1 + if REDUCE_TYPE == 0: # Mean + # For better accuracy you may upcast to fp32 before sum. vec = (tl.sum(page_blk, axis=0) / x_D0).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max vec = tl.max(page_blk, axis=0).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min vec = tl.min(page_blk, axis=0).to(tl.bfloat16) - else: + else: # L2Norm (NOT RMS) s = tl.sum(page_blk * page_blk, axis=0) vec = tl.sqrt(s).to(tl.bfloat16) + # ---- write to output (token-major) ---- out_base = (token_id * NUM_KV_HEAD + head_id) * x_D1 tl.store(output + out_base + tl.arange(0, x_D1), vec) else: - if REDUCE_TYPE == 0: + # DIM == 2: over cols -> axis=1 -> vector len x_D0 + if REDUCE_TYPE == 0: # Mean vec = (tl.sum(page_blk, axis=1) / x_D1).to(tl.bfloat16) - elif REDUCE_TYPE == 1: + elif REDUCE_TYPE == 1: # Max vec = tl.max(page_blk, axis=1).to(tl.bfloat16) - elif REDUCE_TYPE == 2: + elif REDUCE_TYPE == 2: # Min vec = tl.min(page_blk, axis=1).to(tl.bfloat16) - else: + else: # L2Norm (NOT RMS) s = tl.sum(page_blk * page_blk, axis=1) vec = tl.sqrt(s).to(tl.bfloat16) @@ -459,6 +571,7 @@ def reduce_rr_kernel( tl.store(output + out_base + tl.arange(0, x_D0), vec) + def reduce_rr( x: torch.Tensor, output: torch.Tensor, @@ -466,8 +579,9 @@ def reduce_rr( ctx: Context, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -483,8 +597,9 @@ def reduce_rr( PAGE_SIZE=ctx.page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) @@ -496,8 +611,9 @@ def _reduce_rr( page_size: int, dim: int, reduce_type: ReduceType, -fp8_type: int = 0, +quant_type: int = 0, scale: float = 1.0, +kv_scale_ptr=None, ): NNZ = loc.shape[0] @@ -513,6 +629,7 @@ def _reduce_rr( PAGE_SIZE=page_size, REDUCE_TYPE=reduce_type.value, DIM=dim, - FP8_TYPE=fp8_type, + QUANT_TYPE=quant_type, scale=scale, + kv_scale_ptr=kv_scale_ptr if kv_scale_ptr is not None else x, ) From b9eb71786eb8dd12150f02b4dcd15b053328b931 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 2 Mar 2026 06:33:57 +0000 Subject: [PATCH 05/24] adapt topk kernel from sglang to vortex --- csrc/topk.cu | 1029 +++++++++++++++++++++++++++------- examples/verify_algo.sh | 2 +- examples/verify_algo_fp8.sh | 1 - examples/verify_algo_int8.sh | 1 - 4 files changed, 827 insertions(+), 206 deletions(-) diff --git a/csrc/topk.cu b/csrc/topk.cu index 62d747eb..8a48aad5 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,203 +1,826 @@ -#include "register.h" -#include - - -template -__global__ void TopKOutput_F32_Kernel( -const float* __restrict__ score, -const int* __restrict__ dense_kv_indptr, -const int* __restrict__ sparse_kv_indptr, -const int* __restrict__ dense_kv_indices, -int* __restrict__ sparse_kv_indices, -const int topk_val, -const int page_reserved_bos, -const int page_reserved_eos) -{ - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const float* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; - - float key[ITEM_PER_THREAD]; - int val[ITEM_PER_THREAD]; - - using BLF = cub::BlockLoad; - using BLI = cub::BlockLoad; - using BSI = cub::BlockStore; - using Sort = cub::BlockRadixSort; - - __shared__ union { - typename BLF::TempStorage lf; - typename BLI::TempStorage li; - typename BSI::TempStorage si; - typename Sort::TempStorage sort; - } temp; - - BLF(temp.lf).Load(score_blk, key, nblk, -INFINITY); - __syncthreads(); - BLI(temp.li).Load(idx_blk, val, nblk, 0); - __syncthreads(); - - Sort(temp.sort).SortDescending(key, val); - __syncthreads(); - - const int valid_out = min(topk_val, nblk); - BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); -} - - -template -__global__ void TopKOutput_BF16_Kernel( -const __nv_bfloat16* __restrict__ score, -const int* __restrict__ dense_kv_indptr, -const int* __restrict__ sparse_kv_indptr, -const int* __restrict__ dense_kv_indices, -int* __restrict__ sparse_kv_indices, -const int topk_val, -const int page_reserved_bos, -const int page_reserved_eos) -{ - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const __nv_bfloat16* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; - - const __nv_bfloat16 ninf_bf16 = __float2bfloat16(-CUDART_INF_F); - - __nv_bfloat16 key_bf16[ITEM_PER_THREAD]; - float key[ITEM_PER_THREAD]; - int val[ITEM_PER_THREAD]; - - using BLF = cub::BlockLoad<__nv_bfloat16, NUM_THREADS, ITEM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>; - using BLI = cub::BlockLoad; - using BSI = cub::BlockStore; - using Sort = cub::BlockRadixSort; - - __shared__ union { - typename BLF::TempStorage lf; - typename BLI::TempStorage li; - typename BSI::TempStorage si; - typename Sort::TempStorage sort; - } temp; - - BLF(temp.lf).Load(score_blk, key_bf16, nblk, ninf_bf16); - - #pragma unroll - for (int i = 0; i < ITEM_PER_THREAD; ++i){ - key[i] = __bfloat162float(key_bf16[i]); - } - __syncthreads(); - - BLI(temp.li).Load(idx_blk, val, nblk, 0); - __syncthreads(); - - Sort(temp.sort).SortDescending(key, val); - __syncthreads(); - - const int valid_out = min(topk_val, nblk); - BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); -} - - - -void topk_output( -const at::Tensor& x, -const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, -const at::Tensor& dense_kv_indices, -at::Tensor& sparse_kv_indices, -const int64_t eff_batch_size, -const int64_t topk_val, -const int64_t reserved_bos, -const int64_t reserved_eos, -const int64_t max_num_pages -){ - - - dim3 nblks(eff_batch_size); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (max_num_pages <= 128){ - TopKOutput_BF16_Kernel<128, 1><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 256){ - TopKOutput_BF16_Kernel<128, 2><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 512){ - TopKOutput_BF16_Kernel<128, 4><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 1024){ - TopKOutput_BF16_Kernel<256, 4><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 2048){ - TopKOutput_BF16_Kernel<256, 8><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else if (max_num_pages <= 4096){ - TopKOutput_BF16_Kernel<512, 8><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos - ); - } else { - TORCH_CHECK(false); - } - -} +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + constexpr int TopK = 2048; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex integration: BOS/EOS-aware segmented TopK with index remapping + // ====================================================================== + + template + __device__ __forceinline__ float vortex_to_float(T x); + + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + constexpr int VORTEX_MAX_TOPK = 2048; + + // Templated version of fast_topk_cuda_tl: + // - ScoreT: float or __nv_bfloat16 + // - target_k: runtime parameter (replaces compile-time TopK) + template + __device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) + { + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + // Wrapper kernel: one CUDA block per batch*head segment + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) + { + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex host entry point — same interface as topk_output in topk.cu + // ====================================================================== + void topk_output( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); + } \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index d80f09a5..74877088 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=1 +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh index 7f266e5e..fd85dadc 100755 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_fp8.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=3 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh index 4cf1366b..e57c63f5 100644 --- a/examples/verify_algo_int8.sh +++ b/examples/verify_algo_int8.sh @@ -1,6 +1,5 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=2 sparse_algos=( "block_sparse_attention" From ede862425998eaa1e5a3b449dec274798d0ffb1c Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 9 Mar 2026 05:22:43 +0000 Subject: [PATCH 06/24] add parameter to switch between two topk kernels (naive or sglang) --- csrc/register.cc | 1 + csrc/register.h | 12 + csrc/topk.cu | 1029 ++++++--------------------- examples/verify_algo.py | 15 +- examples/verify_algo_fp8.sh | 1 + examples/verify_algo_int8.sh | 1 + vortex_torch/indexer/context.py | 4 +- vortex_torch/indexer/output_func.py | 58 +- 8 files changed, 276 insertions(+), 845 deletions(-) diff --git a/csrc/register.cc b/csrc/register.cc index fd9d4eb2..532fcdfa 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,6 +8,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); m.def("topk_output", &topk_output); + m.def("topk_output_sglang", &topk_output_sglang); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 92499ed6..b81168bb 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -85,6 +85,18 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_seq_lengths +); void sglang_plan_decode_fa3( const at::Tensor& cached_seq_lens, diff --git a/csrc/topk.cu b/csrc/topk.cu index 8a48aad5..3aa49b98 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -1,826 +1,203 @@ -/** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. fix the potential illegal memory access - */ - #include - #include - #include - #include - #include - #include - #include - #include - #include - - #include - #include - #include - - namespace { - - constexpr int TopK = 2048; - constexpr int kThreadsPerBlock = 1024; - - #ifdef USE_ROCM - // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a - // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. - #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES - constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); - #else - constexpr size_t kSmem = 48 * 1024; // bytes - #endif - #else - // Reduced from 128KB to 32KB to improve occupancy. - // Each radix pass needs at most ~TopK candidates in the threshold bin, - // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. - constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) - #endif - - struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; - }; - - // when length <= TopK, we can directly write the indices - __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? i : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? src_page_table[i] : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } - } - - __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); - } - - __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); - } - - __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } - } - - auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; - } - - template - void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { - #ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex integration: BOS/EOS-aware segmented TopK with index remapping - // ====================================================================== - - template - __device__ __forceinline__ float vortex_to_float(T x); - - template <> - __device__ __forceinline__ float vortex_to_float(float x) { return x; } - - template <> - __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); - } - - constexpr int VORTEX_MAX_TOPK = 2048; - - // Templated version of fast_topk_cuda_tl: - // - ScoreT: float or __nv_bfloat16 - // - target_k: runtime parameter (replaces compile-time TopK) - template - __device__ void fast_topk_vortex( - const ScoreT* __restrict__ input, - int* __restrict__ index, - int row_start, - int length, - int target_k) - { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; - - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // Stage 1: 8-bit coarse histogram - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&vh_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - convert_to_uint8(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - // Wrapper kernel: one CUDA block per batch*head segment - template - __global__ __launch_bounds__(kThreadsPerBlock) - void TopKOutput_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos) - { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } - } - - } // namespace - - #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - - void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); - CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex host entry point — same interface as topk_output in topk.cu - // ====================================================================== - void topk_output( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages) - { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); - } \ No newline at end of file +#include "register.h" +#include + + +template +__global__ void TopKOutput_F32_Kernel( +const float* __restrict__ score, +const int* __restrict__ dense_kv_indptr, +const int* __restrict__ sparse_kv_indptr, +const int* __restrict__ dense_kv_indices, +int* __restrict__ sparse_kv_indices, +const int topk_val, +const int page_reserved_bos, +const int page_reserved_eos) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const float* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; + + float key[ITEM_PER_THREAD]; + int val[ITEM_PER_THREAD]; + + using BLF = cub::BlockLoad; + using BLI = cub::BlockLoad; + using BSI = cub::BlockStore; + using Sort = cub::BlockRadixSort; + + __shared__ union { + typename BLF::TempStorage lf; + typename BLI::TempStorage li; + typename BSI::TempStorage si; + typename Sort::TempStorage sort; + } temp; + + BLF(temp.lf).Load(score_blk, key, nblk, -INFINITY); + __syncthreads(); + BLI(temp.li).Load(idx_blk, val, nblk, 0); + __syncthreads(); + + Sort(temp.sort).SortDescending(key, val); + __syncthreads(); + + const int valid_out = min(topk_val, nblk); + BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); +} + + +template +__global__ void TopKOutput_BF16_Kernel( +const __nv_bfloat16* __restrict__ score, +const int* __restrict__ dense_kv_indptr, +const int* __restrict__ sparse_kv_indptr, +const int* __restrict__ dense_kv_indices, +int* __restrict__ sparse_kv_indices, +const int topk_val, +const int page_reserved_bos, +const int page_reserved_eos) +{ + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const __nv_bfloat16* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + sparse_kv_indptr[bx] + page_reserved_bos; + + const __nv_bfloat16 ninf_bf16 = __float2bfloat16(-CUDART_INF_F); + + __nv_bfloat16 key_bf16[ITEM_PER_THREAD]; + float key[ITEM_PER_THREAD]; + int val[ITEM_PER_THREAD]; + + using BLF = cub::BlockLoad<__nv_bfloat16, NUM_THREADS, ITEM_PER_THREAD, cub::BLOCK_LOAD_WARP_TRANSPOSE>; + using BLI = cub::BlockLoad; + using BSI = cub::BlockStore; + using Sort = cub::BlockRadixSort; + + __shared__ union { + typename BLF::TempStorage lf; + typename BLI::TempStorage li; + typename BSI::TempStorage si; + typename Sort::TempStorage sort; + } temp; + + BLF(temp.lf).Load(score_blk, key_bf16, nblk, ninf_bf16); + + #pragma unroll + for (int i = 0; i < ITEM_PER_THREAD; ++i){ + key[i] = __bfloat162float(key_bf16[i]); + } + __syncthreads(); + + BLI(temp.li).Load(idx_blk, val, nblk, 0); + __syncthreads(); + + Sort(temp.sort).SortDescending(key, val); + __syncthreads(); + + const int valid_out = min(topk_val, nblk); + BSI(temp.si).Store(out_blk, /*per-thread regs*/ val, valid_out); +} + + + +void topk_output( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +){ + + + dim3 nblks(eff_batch_size); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + if (max_num_pages <= 128){ + TopKOutput_BF16_Kernel<128, 1><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 256){ + TopKOutput_BF16_Kernel<128, 2><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 512){ + TopKOutput_BF16_Kernel<128, 4><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 1024){ + TopKOutput_BF16_Kernel<256, 4><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 2048){ + TopKOutput_BF16_Kernel<256, 8><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else if (max_num_pages <= 4096){ + TopKOutput_BF16_Kernel<512, 8><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); + } else { + TORCH_CHECK(false); + } + +} \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 9958b7e3..1187aca2 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -56,12 +56,13 @@ def verify_algos( sparse_attention: bool = True, mem: float = 0.8, kv_cache_dtype: str = "auto", +topk_type: str = "naive", ): - llm = sgl.Engine(model_path=model_name, + llm = sgl.Engine(model_path=model_name, disable_cuda_graph=False, page_size=page_size, - vortex_topk_val=topk_val, + vortex_topk_val=topk_val, disable_overlap_schedule=True, attention_backend="flashinfer", enable_vortex_sparsity=sparse_attention, @@ -72,6 +73,7 @@ def verify_algos( vortex_max_seq_lens=12288, mem_fraction_static=mem, kv_cache_dtype=kv_cache_dtype, + vortex_topk_type=topk_type, ) with open("amc23.jsonl", "r", encoding="utf-8") as f: @@ -221,6 +223,14 @@ def parse_args(): choices=["auto", "fp8_e5m2", "fp8_e4m3", "int8"], help='KV cache dtype (default: "auto").', ) + + parser.add_argument( + "--topk-type", + type=str, + default="naive", + choices=["naive", "sglang"], + help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', + ) return parser.parse_args() if __name__ == "__main__": @@ -235,6 +245,7 @@ def parse_args(): sparse_attention=not(args.full_attention), mem=args.mem, kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, ) print(summary) diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_fp8.sh index fd85dadc..c0b8814d 100755 --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_fp8.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh index e57c63f5..bf24c2d1 100644 --- a/examples/verify_algo_int8.sh +++ b/examples/verify_algo_int8.sh @@ -1,5 +1,6 @@ #!/usr/bin/env bash set -e +# export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 6d3c586c..d6da9c1a 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -22,7 +22,7 @@ class Context(ContextBase): # hardware / paging "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc - "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", + "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", # auxilary memory in graph "_aux_total_bytes", @@ -68,6 +68,7 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). + topk_type: str #: TopK kernel type: "naive" or "sglang". # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -144,6 +145,7 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos + self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 5df795b6..f7d0d9c2 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Callable, Optional from ..abs import vOp -from vortex_torch_C import topk_output +from vortex_torch_C import topk_output, topk_output_sglang from .context import Context from ..abs import vTensor, FORMAT @@ -75,13 +75,17 @@ class topK(vOp): """ # Dispatch by input format; only RAGGED is supported for now. - _impl_map: Dict[FORMAT, Callable] = { - FORMAT.RAGGED: topk_output, + _impl_map: Dict[FORMAT, Dict[str, Callable]] = { + FORMAT.RAGGED: { + "naive": topk_output, + "sglang": topk_output_sglang, + }, } def __init__(self): super().__init__() self.impl: Optional[Callable] = None + self.topk_type: str = "naive" # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -152,7 +156,13 @@ def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: f"{prefix}no implementation for x._format={x_fmt}. " f"Available: {list(self._impl_map.keys())}" ) - self.impl = self._impl_map[x_fmt] + self.topk_type = getattr(ctx, "topk_type", "naive") + impl_variants = self._impl_map[x_fmt] + assert self.topk_type in impl_variants, ( + f"{prefix}no topk implementation for topk_type='{self.topk_type}'. " + f"Available: {list(impl_variants.keys())}" + ) + self.impl = impl_variants[self.topk_type] # ---- optional sanity checks on `o` ---- # We only assert device consistency and leave exact (S_pack, D0, D1) @@ -220,16 +230,32 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso prefix = self._prefix() assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" - self.impl( - x, - ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, - ctx.dense_kv_indices, - o, - ctx.batch_size * ctx.num_kv_heads, - ctx.topk_val, - ctx.page_reserved_bos, - ctx.page_reserved_eos, - ctx.max_num_pages_per_request, - ) + if self.topk_type == "sglang": + # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) + else: + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) + self.impl( + x, + ctx.dense_kv_indptr, + ctx.dense_kv_indices, + ctx.sparse_kv_indptr, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) return o From edbf7899b34bb70724839b466c208750c2f6df94 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 9 Mar 2026 05:30:20 +0000 Subject: [PATCH 07/24] add parameter to switch between two topk kernels (naive or sglang) --- examples/verify_algo.sh | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 74877088..73ac2f43 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -19,6 +19,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-val 30 \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/setup.py b/setup.py index f35ddae6..99c6529b 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ 'csrc/register.cc', 'csrc/utils_sglang.cu', 'csrc/topk.cu', + 'csrc/topk_sglang.cu', ], include_dirs=['csrc'], extra_compile_args={ From 87d7664c8cefcbf23e835ece5df549e53863772a Mon Sep 17 00:00:00 2001 From: UED Date: Wed, 18 Mar 2026 23:40:58 +0000 Subject: [PATCH 08/24] add aim24 --- csrc/topk_sglang.cu | 826 +++++++++++++++++++++++++++++++++++++++ examples/verify_aim24.py | 111 ++++++ 2 files changed, 937 insertions(+) create mode 100644 csrc/topk_sglang.cu create mode 100644 examples/verify_aim24.py diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu new file mode 100644 index 00000000..314f0fde --- /dev/null +++ b/csrc/topk_sglang.cu @@ -0,0 +1,826 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + constexpr int TopK = 2048; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex integration: BOS/EOS-aware segmented TopK with index remapping + // ====================================================================== + + template + __device__ __forceinline__ float vortex_to_float(T x); + + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + constexpr int VORTEX_MAX_TOPK = 2048; + + // Templated version of fast_topk_cuda_tl: + // - ScoreT: float or __nv_bfloat16 + // - target_k: runtime parameter (replaces compile-time TopK) + template + __device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) + { + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + // Wrapper kernel: one CUDA block per batch*head segment + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) + { + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + // ====================================================================== + // Vortex host entry point — same interface as topk_output in topk.cu + // ====================================================================== + void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); + } \ No newline at end of file diff --git a/examples/verify_aim24.py b/examples/verify_aim24.py new file mode 100644 index 00000000..51526804 --- /dev/null +++ b/examples/verify_aim24.py @@ -0,0 +1,111 @@ +import json +import sys +sys.path.append("../") +import python.sglang as sgl +from transformers import AutoTokenizer +import os +from tqdm import tqdm +import time +import torch +os.environ["TOKENIZERS_PARALLELISM"] = "false" +MATH_QUERY_TEMPLATE = """ +Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. + +{Question} +""".strip() + +from datasets import load_dataset, Dataset, concatenate_datasets +def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial: int = 1, rank: int = 0, world_size: int = 1): + requests = [] + + # Step 1: Expand dataset trial times + if trial > 1: + dataset = Dataset.from_dict(dataset.to_dict().copy())  # ensure copy + datasets = [dataset] * trial + dataset = concatenate_datasets(datasets) + + total = len(dataset) + + # Step 2: Partition across ranks + per_proc = total // world_size + remainder = total % world_size + start = rank * per_proc + min(rank, remainder) + end = start + per_proc + (1 if rank < remainder else 0) + subset = dataset.select(list(range(start, end))) + + # Step 3: Format requests + for data in dataset: + conversations = [ + {"role": "user", "content": data_format.format(Question=data[field_name])} + ] + data["conversations"] = conversations + requests.append(data) + + return requests + + + + + + + +def main(): + model_name = "Qwen/Qwen3-0.6B" + llm = sgl.Engine(model_path=model_name, + disable_cuda_graph=False, + page_size=16, + vortex_num_selected_pages=29, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + mem_fraction_static=0.9, + vortex_cg=True, + vortex_graph=True, + vortex_module_name="block_sparse_attention", + vortex_max_seq_lens=20480 + ) + + dataset = load_dataset("HuggingFaceH4/aime_2024", split="train") + + requests = generate_requests(dataset, "problem", MATH_QUERY_TEMPLATE) + + + + texts = [ + x["conversations"] for x in requests + ] + + tokenizer = AutoTokenizer.from_pretrained(model_name) + prompts = [ + tokenizer.apply_chat_template( + text, + tokenize=False, + add_generation_prompt=True, + enable_thinking=True + ) for text in texts + ] * 8 + + sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 16384} + total_tokens = 0 + total_time = 0.0 + start = time.perf_counter() + o = llm.generate(prompts, sampling_params) + elapsed = time.perf_counter() - start + total_time += elapsed + e2e_time = 0 + with open(f"0.6B_VTX_CG_TP1_16K.jsonl", "w", encoding="utf-8") as f: + for item in o: + total_tokens += item["meta_info"]["completion_tokens"] + e2e_time = max(e2e_time, item["meta_info"]["e2e_latency"]) + json.dump(item, f, ensure_ascii=False) + f.write("\n") + + meta_data = {"e2e_time": e2e_time, "total_time": total_time, "total_tokens": total_tokens, "throughput": total_tokens / total_time} + json.dump(meta_data, f, ensure_ascii=False) + f.write("\n") + +if __name__ == "__main__": + main() \ No newline at end of file From 66237d748ee2c69977c208c5a64faa2d382ffd79 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 24 Mar 2026 18:22:40 +0000 Subject: [PATCH 09/24] Implement sparse prefill with topk on a new ragged only warpper --- examples/verify_aim24.py | 5 ---- examples/verify_algo.sh | 2 +- examples/verify_algo_int8.sh | 25 ------------------- ...erify_algo_fp8.sh => verify_algo_quant.sh} | 18 +++++++++++-- vortex_torch/cache/context.py | 4 +++ vortex_torch/flow/flow.py | 1 + 6 files changed, 22 insertions(+), 33 deletions(-) delete mode 100644 examples/verify_algo_int8.sh rename examples/{verify_algo_fp8.sh => verify_algo_quant.sh} (55%) mode change 100755 => 100644 diff --git a/examples/verify_aim24.py b/examples/verify_aim24.py index 51526804..9e54a967 100644 --- a/examples/verify_aim24.py +++ b/examples/verify_aim24.py @@ -44,11 +44,6 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial return requests - - - - - def main(): model_name = "Qwen/Qwen3-0.6B" llm = sgl.Engine(model_path=model_name, diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 73ac2f43..8416e544 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -# export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=1 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_int8.sh b/examples/verify_algo_int8.sh deleted file mode 100644 index bf24c2d1..00000000 --- a/examples/verify_algo_int8.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/usr/bin/env bash -set -e -# export CUDA_VISIBLE_DEVICES=0 - -sparse_algos=( - "block_sparse_attention" -) - -RESULTS_DIR="results" -mkdir -p "${RESULTS_DIR}" -TIMESTAMP=$(date +%Y%m%d_%H%M%S) - - for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --kv-cache-dtype int8 \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done \ No newline at end of file diff --git a/examples/verify_algo_fp8.sh b/examples/verify_algo_quant.sh old mode 100755 new mode 100644 similarity index 55% rename from examples/verify_algo_fp8.sh rename to examples/verify_algo_quant.sh index c0b8814d..c344474a --- a/examples/verify_algo_fp8.sh +++ b/examples/verify_algo_quant.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -# export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=0 sparse_algos=( "block_sparse_attention" @@ -11,6 +11,20 @@ mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_int8_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype int8" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --kv-cache-dtype int8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done + + for algo in "${sparse_algos[@]}"; do OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" echo ">>> Saving results to ${OUTFILE}" @@ -22,4 +36,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --kv-cache-dtype fp8_e4m3 \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done + done \ No newline at end of file diff --git a/vortex_torch/cache/context.py b/vortex_torch/cache/context.py index 0e7171cc..3cdf0953 100644 --- a/vortex_torch/cache/context.py +++ b/vortex_torch/cache/context.py @@ -24,9 +24,11 @@ class Context(ContextBase): # Quantization: quant_type (0=none, 1=int8, 2=e4m3, 3=e5m2), # kv_scale (per-tensor fp8 scale), kv_scale_ptr (per-token int8 scale tensor) + # fp8_type: 0=none, 1=e4m3, 2=e5m2 (encoding for Triton kernels) "quant_type", "kv_scale", "kv_scale_ptr", + "fp8_type", ) @@ -49,6 +51,8 @@ def __init__(self) -> None: object.__setattr__(self, name, 1.0) # identity scale for bf16 elif name == "kv_scale_ptr": object.__setattr__(self, name, None) # per-token scale tensor (int8 only) + elif name == "fp8_type": + object.__setattr__(self, name, 0) # 0 = none (bf16 default) else: object.__setattr__(self, name, UNSET) diff --git a/vortex_torch/flow/flow.py b/vortex_torch/flow/flow.py index 7efc80e9..7da5c72c 100644 --- a/vortex_torch/flow/flow.py +++ b/vortex_torch/flow/flow.py @@ -431,6 +431,7 @@ def run_indexer_virtual(self, group_size: int, page_size: int, head_dim: int): ctx.page_size = page_size ctx.max_num_pages = 0 ctx.max_num_pages_per_request = 0 + ctx.topk_type = "naive" device = "cuda" dtype = torch.bfloat16 From 9a73a8cccf635a055564d5f5fb155d152854748c Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 29 Mar 2026 05:01:09 +0000 Subject: [PATCH 10/24] fix on the ragged warpper, using single ragged warpper on concated rags and pages; fix on the previous quantization implementaion, with lanuch_graph dtype set to the quant type --- examples/verify_algo.py | 17 +++++++++++------ examples/verify_algo.sh | 7 +++++-- third_party/sglang | 2 +- vortex_torch/flow/__init__.py | 4 +++- vortex_torch/indexer/output_func.py | 4 ++-- 5 files changed, 22 insertions(+), 12 deletions(-) diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 1187aca2..91f92e76 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -142,12 +142,17 @@ def verify_algos( if sparse_attention: llm_cfg = AutoConfig.from_pretrained(model_name) - flow = vortex_torch.flow.build_vflow(vortex_module_name) - memory_access_runtime = flow.run_indexer_virtual( - group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, - page_size=page_size, - head_dim=llm_cfg.head_dim, - ) + flow = vortex_torch.flow.build_vflow(vortex_module_name) + try: + memory_access_runtime = flow.run_indexer_virtual( + group_size=llm_cfg.num_attention_heads // llm_cfg.num_key_value_heads, + page_size=page_size, + head_dim=llm_cfg.head_dim, + ) + except Exception: + # External algorithms (nsa, fsa, flash_moba) override run_indexer_virtual + # to return 0 since their vendored kernels don't participate in vortex profiling + memory_access_runtime = 0.0 else: memory_access_runtime = 0.0 diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 8416e544..cc174b41 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,9 +1,12 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=1 +export CUDA_VISIBLE_DEVICES=7 sparse_algos=( "block_sparse_attention" + "nsa" + "fsa" + "flash_moba" ) RESULTS_DIR="results" @@ -19,7 +22,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-val 30 \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ + --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/third_party/sglang b/third_party/sglang index 7105719f..20e4c29d 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 7105719f0a2ac464ee7ffdc0a899fa6a656656a2 +Subproject commit 20e4c29d206046d6b4eb3b57cc26fd20bf9c519b diff --git a/vortex_torch/flow/__init__.py b/vortex_torch/flow/__init__.py index b2fcadc2..bb60b895 100644 --- a/vortex_torch/flow/__init__.py +++ b/vortex_torch/flow/__init__.py @@ -34,9 +34,11 @@ class BlockSparseAttention(vFlow): from .registry import register from .loader import build_vflow from . import algorithms +from . import external_algorithms __all__ = [ "vFlow", "register", "build_vflow", - "algorithms" + "algorithms", + "external_algorithms", ] \ No newline at end of file diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index f7d0d9c2..8859d61a 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -245,12 +245,12 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.max_num_pages_per_request, ) else: - # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) + # topk_output (naive): (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) self.impl( x, ctx.dense_kv_indptr, - ctx.dense_kv_indices, ctx.sparse_kv_indptr, + ctx.dense_kv_indices, o, ctx.batch_size * ctx.num_kv_heads, ctx.topk_val, From a8fd32854d62ec677a3467c6f29557bf9540a2cd Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 30 Mar 2026 02:35:40 +0000 Subject: [PATCH 11/24] Sparse attention kernel apdation with full attention kernels, include (naive sparse attention, flash sparse attention, flashmoba) --- examples/verify_algo.sh | 5 +- examples/verify_algo_topk.sh | 44 + examples/verify_sparse_backends.sh | 27 + vortex_torch/attention_backend/__init__.py | 3 + .../attention_backend/flashmoba/__init__.py | 13 + .../flashmoba/flash_moba_interface.py | 730 ++++++ .../flashmoba/triton_mean_pool.py | 158 ++ .../fsa/FSA_topk_sparse_attention.py | 2040 +++++++++++++++++ .../attention_backend/fsa/__init__.py | 9 + .../attention_backend/nsa/__init__.py | 9 + .../nsa/topk_sparse_attention.py | 1280 +++++++++++ vortex_torch/flow/external_algorithms.py | 76 + vortex_torch/kernels/__init__.py | 0 vortex_torch/kernels/fsa/__init__.py | 5 + .../kernels/fsa/fused_score_kernels.py | 300 +++ vortex_torch/kernels/nsa/__init__.py | 24 + .../kernels/nsa/compressed_attention.py | 1317 +++++++++++ vortex_torch/kernels/nsa/flash_attention.py | 886 +++++++ vortex_torch/kernels/nsa/utils.py | 50 + vortex_torch/kernels/nsa/weighted_pool.py | 341 +++ 20 files changed, 7313 insertions(+), 4 deletions(-) create mode 100644 examples/verify_algo_topk.sh create mode 100755 examples/verify_sparse_backends.sh create mode 100644 vortex_torch/attention_backend/__init__.py create mode 100644 vortex_torch/attention_backend/flashmoba/__init__.py create mode 100644 vortex_torch/attention_backend/flashmoba/flash_moba_interface.py create mode 100644 vortex_torch/attention_backend/flashmoba/triton_mean_pool.py create mode 100644 vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py create mode 100644 vortex_torch/attention_backend/fsa/__init__.py create mode 100644 vortex_torch/attention_backend/nsa/__init__.py create mode 100644 vortex_torch/attention_backend/nsa/topk_sparse_attention.py create mode 100644 vortex_torch/flow/external_algorithms.py create mode 100644 vortex_torch/kernels/__init__.py create mode 100644 vortex_torch/kernels/fsa/__init__.py create mode 100644 vortex_torch/kernels/fsa/fused_score_kernels.py create mode 100644 vortex_torch/kernels/nsa/__init__.py create mode 100644 vortex_torch/kernels/nsa/compressed_attention.py create mode 100644 vortex_torch/kernels/nsa/flash_attention.py create mode 100644 vortex_torch/kernels/nsa/utils.py create mode 100644 vortex_torch/kernels/nsa/weighted_pool.py diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index cc174b41..0dcbe9fc 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,12 +1,9 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=7 +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( "block_sparse_attention" - "nsa" - "fsa" - "flash_moba" ) RESULTS_DIR="results" diff --git a/examples/verify_algo_topk.sh b/examples/verify_algo_topk.sh new file mode 100644 index 00000000..6b2744ae --- /dev/null +++ b/examples/verify_algo_topk.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "block_sparse_attention" +) + +RESULTS_DIR="results" +REPEAT_COUNT="${REPEAT_COUNT:-3}" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_naive_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type naive" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +for repeat_idx in $(seq 1 "${REPEAT_COUNT}"); do + for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_sglang_${TIMESTAMP}_run${repeat_idx}.log" + echo ">>> Run ${repeat_idx}/${REPEAT_COUNT}: verify_algo.py with --vortex-module-name ${algo} --topk-type sglang" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done \ No newline at end of file diff --git a/examples/verify_sparse_backends.sh b/examples/verify_sparse_backends.sh new file mode 100755 index 00000000..81b3562d --- /dev/null +++ b/examples/verify_sparse_backends.sh @@ -0,0 +1,27 @@ +#!/usr/bin/env bash +set -e +export CUDA_VISIBLE_DEVICES=5 + +sparse_algos=( + "nsa" + "fsa" + "flash_moba" +) + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/${algo}_bf16_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done diff --git a/vortex_torch/attention_backend/__init__.py b/vortex_torch/attention_backend/__init__.py new file mode 100644 index 00000000..9ca7855b --- /dev/null +++ b/vortex_torch/attention_backend/__init__.py @@ -0,0 +1,3 @@ +# Vendored sparse attention backends for Vortex forward_extend. +# NSA and FSA are pure Triton kernels. +# FlashMoBA requires flash_moba_cuda C++ extension (pip install flash_moba). diff --git a/vortex_torch/attention_backend/flashmoba/__init__.py b/vortex_torch/attention_backend/flashmoba/__init__.py new file mode 100644 index 00000000..aa912b91 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/__init__.py @@ -0,0 +1,13 @@ +from .flash_moba_interface import ( + flash_moba_varlen_func, + flash_moba_attn_varlen_func, + flash_topk_varlen_func, + decide_lg_block_m, +) + +__all__ = [ + "flash_moba_varlen_func", + "flash_moba_attn_varlen_func", + "flash_topk_varlen_func", + "decide_lg_block_m", +] diff --git a/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py new file mode 100644 index 00000000..c196c21d --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/flash_moba_interface.py @@ -0,0 +1,730 @@ +from typing import Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +import os + +try: + import flash_moba_cuda as flash_moba_gpu +except ImportError: + flash_moba_gpu = None +from .triton_mean_pool import flash_topk_mean_pool + +########################################################################################################################## +# Helper functions +########################################################################################################################## + +def round_multiple(x: int, m: int) -> int: + """Round x up to the nearest multiple of m.""" + return ((x + m - 1) // m) * m + +########################################################################################################################## + +def decide_lg_block_m(top_k: int, chunk_size: int, seqlen: int, causal: bool = False) -> int: + sparsity = 0.0 + budget = top_k * chunk_size + if causal: + density = (2*(budget * seqlen) - budget**2) / (seqlen**2) + else: + density = budget / seqlen + + sparsity = 1 - density + + if sparsity <= 0.5: + lg_block_m = 128 + elif sparsity <= 0.7: + lg_block_m = 256 + elif sparsity <= 0.8: + lg_block_m = 512 + elif sparsity <= 0.9: + lg_block_m = 768 + else: + lg_block_m = 1024 + + # [Optimization] Hardware-aware cap for A6000/3090/4090 to avoid Shared Memory OOM + if torch.cuda.is_available(): + major, minor = torch.cuda.get_device_capability() + # sm86 (A6000, 3090) and sm89 (4090, L40) have smaller shared memory than A100 (sm80) + if major == 8 and minor > 0: + lg_block_m = min(lg_block_m, 512) + + return lg_block_m + +########################################################################################################################## + +# torch.compile() support is only enabled for pytorch >= 2.4 +# The reason for this is that we are using the new custom_op and register_fake +# APIs, which support inplace modification of inputs in the function itself +if torch.__version__ >= "2.4.0": + _torch_custom_op_wrapper = torch.library.custom_op + _torch_register_fake_wrapper = torch.library.register_fake +else: + def noop_custom_op_wrapper(name, fn=None, /, *, mutates_args, device_types=None, schema=None): + def wrap(func): + return func + if fn is None: + return wrap + return fn + def noop_register_fake_wrapper(op, fn=None, /, *, lib=None, _stacklevel=1): + def wrap(func): + return func + if fn is None: + return wrap + return fn + _torch_custom_op_wrapper = noop_custom_op_wrapper + _torch_register_fake_wrapper = noop_register_fake_wrapper + + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + + +########################################################################################################################## +# Custom ops +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_moba_fused_topk", mutates_args=(), device_types="cuda") +def _moba_fused_topk( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + + col_offsets, col_nnz, indices, _, _ = flash_moba_gpu.moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal, + ) + return col_offsets, col_nnz, indices + +@_torch_register_fake_wrapper("flash_moba::_moba_fused_topk") +def _moba_fused_topk_fake( + q: torch.Tensor, + km: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_seqlens_km: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_topk: int, + moba_chunk_size: int, + causal: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + q, km = [maybe_contiguous(x) for x in (q, km)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + max_lg_col_num = (max_seqlen_k + moba_chunk_size - 1) // moba_chunk_size + + col_offsets = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int64) + col_nnz = torch.empty((batch_size, num_heads, max_lg_col_num), device=q.device, dtype=torch.int32) + indices = torch.empty((total_q * num_heads * moba_topk), device=q.device, dtype=torch.int32) + + return col_offsets, col_nnz, indices + +if torch.__version__ >= "2.4.0": + _wrapped_moba_fused_topk = torch.ops.flash_moba._moba_fused_topk +else: + _wrapped_moba_fused_topk = _moba_fused_topk + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_varlen_sort", mutates_args=(), device_types="cuda") +def _varlen_sort( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) + return flash_moba_gpu.varlen_sort( + col_offsets.view(-1), col_offset_ends, indices + ) + +@_torch_register_fake_wrapper("flash_moba::_varlen_sort") +def _varlen_sort_fake( + col_offsets: torch.Tensor, + col_nnz: torch.Tensor, + indices: torch.Tensor, +) -> torch.Tensor: + # varlen_sort is out-of-place + col_offset_ends = col_offsets.view(-1) + col_nnz.view(-1) + return torch.empty_like(indices) + +if torch.__version__ >= "2.4.0": + _wrapped_varlen_sort = torch.ops.flash_moba._varlen_sort +else: + _wrapped_varlen_sort = _varlen_sort + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_forward", mutates_args=(), device_types="cuda") +def _flash_moba_attn_varlen_forward( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + moba_col_offsets = maybe_contiguous(moba_col_offsets) + moba_col_nnz = maybe_contiguous(moba_col_nnz) + moba_row_indices = maybe_contiguous(moba_row_indices) + + out, softmax_lse, S_dmask, rng_state = flash_moba_gpu.moba_varlen_fwd( + q, + k, + v, + None, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + return_softmax, + lg_block_m, + lg_block_n, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, softmax_lse, S_dmask, rng_state + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_forward") +def _flash_moba_attn_varlen_forward_fake( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float = 0.0, + alibi_slopes: Optional[torch.Tensor] = None, + return_softmax: bool = False, + leftpad_k: Optional[torch.Tensor] = None, + seqused_k: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + out = torch.empty_like(q) + softmax_lse = torch.empty((num_heads, total_q), dtype=torch.float32, device=q.device, layout=q.layout) + p = torch.empty((0,), dtype=q.dtype, device=q.device, layout=q.layout) + seqlen_q_rounded = round_multiple(max_seqlen_q, 128) + seqlen_k_rounded = round_multiple(max_seqlen_k, 128) + if return_softmax: + p = torch.empty((batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded), dtype=q.dtype, device=q.device, layout=q.layout) + rng_state = torch.empty((2,), dtype=torch.int64, device=q.device) + return out, softmax_lse, p, rng_state + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_forward = torch.ops.flash_moba._flash_moba_attn_varlen_forward +else: + _wrapped_flash_moba_attn_varlen_forward = _flash_moba_attn_varlen_forward + +########################################################################################################################## + +@_torch_custom_op_wrapper("flash_moba::_flash_moba_attn_varlen_backward", mutates_args=("dq", "dk", "dv"), device_types="cuda") +def _flash_moba_attn_varlen_backward( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_moba_gpu.moba_varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + zero_tensors, + causal, + softcap, + deterministic, + lg_block_m, + lg_block_n, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return softmax_d + +@_torch_register_fake_wrapper("flash_moba::_flash_moba_attn_varlen_backward") +def _flash_moba_attn_varlen_backward_fake( + dout: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + out: torch.Tensor, + softmax_lse: torch.Tensor, + dq: Optional[torch.Tensor], + dk: Optional[torch.Tensor], + dv: Optional[torch.Tensor], + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + # MOBA sparse pattern parameters + moba_col_offsets: torch.Tensor, + moba_col_nnz: torch.Tensor, + moba_row_indices: torch.Tensor, + lg_block_m: int, + lg_block_n: int, + dropout_p: float, + softmax_scale: float, + causal: bool, + softcap: float, + alibi_slopes: Optional[torch.Tensor], + deterministic: bool, + rng_state: Optional[torch.Tensor] = None, + zero_tensors: bool = False, +) -> torch.Tensor: + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + batch_size = cu_seqlens_q.numel() - 1 + total_q, num_heads, _ = q.shape + + if dq is None: + dq = torch.empty_like(q) + if dk is None: + dk = torch.empty_like(k) + if dv is None: + dv = torch.empty_like(v) + softmax_d = torch.empty((num_heads, total_q + 128 * batch_size), device=q.device, dtype=torch.float32) + + return softmax_d + +if torch.__version__ >= "2.4.0": + _wrapped_flash_moba_attn_varlen_backward = torch.ops.flash_moba._flash_moba_attn_varlen_backward +else: + _wrapped_flash_moba_attn_varlen_backward = _flash_moba_attn_varlen_backward + +########################################################################################################################## + +class FlashMobaAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + deterministic, + return_softmax, + is_grad_enabled, + ): + is_grad = is_grad_enabled and any( + x.requires_grad for x in [q, k, v] + ) + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + v = torch.nn.functional.pad(v, [0, 8 - head_size_og % 8]) + out_padded, softmax_lse, S_dmask, rng_state = _wrapped_flash_moba_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal=causal, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + if is_grad: + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, + moba_col_offsets, moba_col_nnz, moba_row_indices + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + ctx.lg_block_m = lg_block_m + ctx.lg_block_n = lg_block_n + + out = out_padded[..., :head_size_og] + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state, moba_col_offsets, moba_col_nnz, moba_row_indices = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + head_size_og = dout.size(2) + dout_padded = dout + if head_size_og % 8 != 0: + dout_padded = torch.nn.functional.pad(dout, [0, 8 - head_size_og % 8]) + _wrapped_flash_moba_attn_varlen_backward( + dout_padded, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + ctx.lg_block_m, + ctx.lg_block_n, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None + +########################################################################################################################## + +def flash_topk_varlen_func( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_topk, + moba_chunk_size, + causal=False, +): + """ + Computes the top-k indices for Mixture-of-Blocks Attention (MOBA). + This function handles variable length sequences. + + Args: + q (torch.Tensor): Query tensor of shape (total_q, num_heads, head_size). + k (torch.Tensor): Key tensor of shape (total_k, num_heads, head_size). + cu_seqlens_q (torch.Tensor): Cumulative sequence lengths for queries, shape (batch_size + 1,). + cu_seqlens_k (torch.Tensor): Cumulative sequence lengths for keys, shape (batch_size + 1,). + max_seqlen_q (int): Maximum sequence length for queries. + max_seqlen_k (int): Maximum sequence length for keys. + moba_topk (int): The number of top-k elements to select. + moba_chunk_size (int): The chunk size for MOBA. + causal (bool): Whether to apply causal masking. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: A tuple containing: + - col_offsets (torch.Tensor): Column offsets for the sparse matrix. + - col_nnz (torch.Tensor): Number of non-zero elements per column block. + - indices (torch.Tensor): The top-k indices. + """ + head_size_og = q.size(2) + if head_size_og % 8 != 0: + q = torch.nn.functional.pad(q, [0, 8 - head_size_og % 8]) + k = torch.nn.functional.pad(k, [0, 8 - head_size_og % 8]) + + km, cu_seqlens_km, _ = flash_topk_mean_pool(k, cu_seqlens_k, max_seqlen_k, moba_chunk_size) + + col_offsets, col_nnz, indices = _wrapped_moba_fused_topk( + q, + km, + cu_seqlens_q, + cu_seqlens_k, + cu_seqlens_km, + max_seqlen_q, + max_seqlen_k, + moba_topk, + moba_chunk_size, + causal=causal + ) + + indices = _wrapped_varlen_sort( + col_offsets, col_nnz, indices + ) + + return col_offsets, col_nnz, indices + +########################################################################################################################## + +def flash_moba_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m=64, + lg_block_n=64, + dropout_p=0.0, + softmax_scale=None, + causal=False, + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch. + k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch. + cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into q. + cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into kv. + max_seqlen_q: int. Maximum query sequence length in the batch. + max_seqlen_k: int. Maximum key sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + moba_col_offsets: Optional[torch.Tensor]. Column offsets for MOBA sparse pattern. + Shape: (batch_size, num_heads, max_lg_col_num), dtype: int64 + moba_col_nnz: Optional[torch.Tensor]. Non-zero counts per column for MOBA sparse pattern. + Shape: (batch_size, num_heads, max_lg_col_num), dtype: int32 + moba_row_indices: Optional[torch.Tensor]. Row indices for MOBA sparse pattern (flattened). + dtype: int32 + lg_block_m: int. Logical block size in M dimension (query). Default: 64 + lg_block_n: int. Logical block size in N dimension (key). Default: 64 + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashMobaAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_col_offsets, + moba_col_nnz, + moba_row_indices, + lg_block_m, + lg_block_n, + dropout_p, + softmax_scale, + causal, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + torch.is_grad_enabled(), + ) + +########################################################################################################################## + +def flash_moba_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + moba_chunk_size, + moba_topk, + causal=True, +): + + col_offsets, col_nnz, indices = flash_topk_varlen_func( + q, + k, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + # MOBA sparse pattern parameters + moba_topk, + moba_chunk_size, + causal=causal, + ) + + lg_block_m = decide_lg_block_m(moba_topk, moba_chunk_size, max_seqlen_k, causal) + + return flash_moba_attn_varlen_func( + q, k, v, + cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, + col_offsets, + col_nnz, + indices, + lg_block_m, + moba_chunk_size, + dropout_p=0.0, + causal=causal, + ) diff --git a/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py b/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py new file mode 100644 index 00000000..6fbd59f1 --- /dev/null +++ b/vortex_torch/attention_backend/flashmoba/triton_mean_pool.py @@ -0,0 +1,158 @@ +# Copyright (c) 2025, FlashMoBA Team. +import torch +import torch.nn.functional as F + +import triton +import triton.language as tl + + +# `triton.jit`'ed functions can be auto-tuned by using the `triton.autotune` decorator, which consumes: +# - A list of `triton.Config` objects that define different configurations of +# meta-parameters (e.g., `BLOCK_SIZE_M`) and compilation options (e.g., `num_warps`) to try +# - An auto-tuning *key* whose change in values will trigger evaluation of all the +# provided configs +@triton.autotune( + configs=[ + # triton.Config({'kBlockN': 16}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 32}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 64}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 64}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 64}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 64}, num_warps=8, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=2, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=4, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=4, num_stages=4), + triton.Config({'kBlockN': 128}, num_warps=8, num_stages=3), + triton.Config({'kBlockN': 128}, num_warps=8, num_stages=4), + # triton.Config({'kBlockN': 256}, num_warps=4, num_stages=3), + # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=3), + # triton.Config({'kBlockN': 256}, num_warps=8, num_stages=4), + # triton.Config({'kBlockN': 256}, num_warps=16, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=8, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=16, num_stages=2), + # triton.Config({'kBlockN': 512}, num_warps=16, num_stages=3), + # triton.Config({'kBlockN': 1024}, num_warps=16, num_stages=2), + ], + key=['HEAD_DIM', 'POOL_BLOCK_SIZE'], +) +@triton.jit +def mean_pool_kernel( + # Pointers to matrices + input_ptr, + output_ptr, + # Matrix dimensions + HEAD_DIM: tl.constexpr, + POOL_BLOCK_SIZE: tl.constexpr, + cu_seqlens_input, + cu_seqlens_output, + input_stride_row, input_stride_head, + output_stride_row, output_stride_head, + # Meta-parameters + kBlockN: tl.constexpr, +): + """ + Triton kernel for mean pooling over variable-length sequences. + + This kernel computes the mean of non-overlapping blocks of size `POOL_BLOCK_SIZE` + for each sequence in a batch. It is designed to handle variable sequence lengths. + + Args: + input_ptr: Pointer to the input tensor of shape (total_seqlen, num_heads, head_dim). + output_ptr: Pointer to the output tensor of shape (total_blocks, num_heads, head_dim). + HEAD_DIM: The dimension of each head. + POOL_BLOCK_SIZE: The size of the pooling window. + cu_seqlens_input: Cumulative sequence lengths of the input tensor, shape (batch_size + 1,). + cu_seqlens_output: Cumulative sequence lengths of the output tensor, shape (batch_size + 1,). + input_stride_row: Stride of the input tensor along the sequence dimension. + input_stride_head: Stride of the input tensor along the head dimension. + output_stride_row: Stride of the output tensor along the sequence dimension. + output_stride_head: Stride of the output tensor along the head dimension. + kBlockN: Block size for the sequence dimension, a meta-parameter for tuning. + """ + n_block = tl.program_id(0) + bidb = tl.program_id(1) + bidh = tl.program_id(2) + + seq_start = tl.load(cu_seqlens_input + bidb) + seq_end = tl.load(cu_seqlens_input + bidb + 1) + + block_start_row = seq_start + n_block * POOL_BLOCK_SIZE + + if seq_end <= block_start_row: + return + + actual_block_size = tl.minimum(POOL_BLOCK_SIZE, seq_end - block_start_row) + + offsets_d = tl.arange(0, HEAD_DIM) + # mask_d = offsets_d < HEAD_DIM + + acc = tl.zeros([HEAD_DIM], dtype=tl.float32) + + for block_k_start in range(0, actual_block_size, kBlockN): + offsets_k = block_k_start + tl.arange(0, kBlockN) + mask_k = offsets_k < actual_block_size + + row_indices = block_start_row + offsets_k + + input_offset = row_indices[:, None] * input_stride_row.to(tl.int64) + bidh * input_stride_head.to(tl.int64) + offsets_d[None, :] + + inp = tl.load(input_ptr + input_offset, mask=mask_k[:, None], other=0.0) + acc += tl.sum(inp, axis=0) + + # safe division + mean_val = acc / actual_block_size + + output_start = tl.load(cu_seqlens_output + bidb) + output_offset = (output_start + n_block) * output_stride_row.to(tl.int64) + bidh * output_stride_head.to(tl.int64) + offsets_d + tl.store(output_ptr + output_offset, mean_val) + + +def flash_topk_mean_pool(input, cu_seqlens_input, max_seqlen_input, pool_block_size): + """ + Performs mean pooling on variable-length sequences using a Triton kernel. + + This function takes a tensor of packed sequences and applies mean pooling over + fixed-size blocks. + + Args: + input (torch.Tensor): The input tensor of shape (total_seqlen, num_heads, head_dim). + cu_seqlens_input (torch.Tensor): Cumulative sequence lengths for the input, shape (batch_size + 1,). + max_seqlen_input (int): The maximum sequence length in the input batch. + pool_block_size (int): The size of the pooling window. + + Returns: + Tuple[torch.Tensor, torch.Tensor, int]: A tuple containing: + - output (torch.Tensor): The pooled output tensor of shape (total_blocks, num_heads, head_dim). + - cu_seqlens_output (torch.Tensor): Cumulative sequence lengths for the output. + - max_seqlen_output (int): The maximum number of blocks for any sequence in the batch. + """ + total_seqlen, head_num, head_dim = input.shape + batch_size = cu_seqlens_input.shape[0] - 1 + + max_seqlen_output = (max_seqlen_input + pool_block_size - 1) // pool_block_size + + actual_input_seqlens = cu_seqlens_input[1:] - cu_seqlens_input[:-1] + actual_output_seqlens = (actual_input_seqlens + pool_block_size - 1) // pool_block_size + cu_seqlens_output = F.pad(torch.cumsum(actual_output_seqlens, dim=0), (1, 0)).to(torch.int32) + + total_blocks = cu_seqlens_output[-1].item() + + output = torch.zeros((total_blocks, head_num, head_dim), dtype=input.dtype, device=input.device) + + grid = (max_seqlen_output, batch_size, head_num) + + mean_pool_kernel[grid]( + input, + output, + head_dim, + pool_block_size, + cu_seqlens_input, + cu_seqlens_output, + input.stride(0), input.stride(1), + output.stride(0), output.stride(1), + ) + + return output, cu_seqlens_output, max_seqlen_output + \ No newline at end of file diff --git a/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py b/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py new file mode 100644 index 00000000..acca2ac8 --- /dev/null +++ b/vortex_torch/attention_backend/fsa/FSA_topk_sparse_attention.py @@ -0,0 +1,2040 @@ +# Copyright 2025 Ran Yan. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from ..nsa.topk_sparse_attention import (backward_sum_o_do, + reorder_topk_idx, + get_num_warps_stages, + is_hopper_gpu) + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_fill_kernel(ptr_tile, ptr_m_i_cur_tiles, N, BLOCK_SIZE: tl.constexpr): + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < N + + tl.store(ptr_tile + offsets, -1, mask=mask) # fill int32 with -1 + tl.store(ptr_m_i_cur_tiles + offsets, float("-inf"), mask=mask) + + +def fused_fill(topk_idx_permuted_tile: torch.Tensor, m_i_cur_tiles): + + numel = topk_idx_permuted_tile.numel() + BLOCK_SIZE = 1024 + + # Flatten for pointer access + tile_flat = topk_idx_permuted_tile.view(-1) + + m_i_cur_tiles_flat = m_i_cur_tiles.view(-1) + + grid = lambda meta: (triton.cdiv(numel, meta['BLOCK_SIZE']),) + + fused_fill_kernel[grid]( + tile_flat, + m_i_cur_tiles_flat, + numel, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=1, + num_stages=3, + ) + + +@triton.jit +def block_to_token_kernel( + topk_idx_ptr, + result_ptr, + N_token, + K, + min_block_id, + max_block_id, + padding_value, + ts_h, + ts_b, + ts_n, + rs_h, + rs_b, + rs_n, + num_q_loops: tl.constexpr, + BLOCK_K: tl.constexpr, +): + pid = tl.program_id(0) # token index i + pid_h = 0 + offs = tl.arange(0, BLOCK_K) # [0, 1, ..., K-1] + + offs_q = tl.arange(0, num_q_loops) + + pid_j = pid * num_q_loops + offs_q + + topk_idx_offset = pid_h * ts_h + pid_j[None, :] * K + offs[:, None] + block_ids = tl.load( + topk_idx_ptr + topk_idx_offset, mask=(pid_j < N_token)[None, :] & (offs < K)[:, None], other=padding_value + ) + + result_ptrs = result_ptr + pid_h * rs_h + block_ids * N_token + pid_j[None, :] + + mask = (block_ids >= 0) & (block_ids != padding_value) & (pid_j < N_token)[None, :] + tl.store(result_ptrs, pid_j[None, :], mask=mask) + + +def build_block_to_token_triton( + result: torch.Tensor, topk_idx: torch.Tensor, min_block_id: int, max_block_id: int, padding_value: int = -1 +): + """ + Args: + topk_idx: [num_heads, N_token, TopK], block indices per token, padded with padding_value for invalid blocks + num_blocks: int + padding_value: int + + Returns: + result: [num_blocks, N_token], token indices per block, padded by padding_value + """ + assert topk_idx.ndim == 3 + assert padding_value == -1 + num_heads, N_token, TopK = topk_idx.shape + + # 每个 token,每个head 一个 program + num_q_loops = 4 + grid = (triton.cdiv(N_token, num_q_loops),) + BLOCK_K = triton.next_power_of_2(TopK) + block_to_token_kernel[grid]( + topk_idx, + result, + N_token, + TopK, + min_block_id, + max_block_id, + padding_value, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + result.stride(0), + result.stride(1), + result.stride(2), + num_q_loops, + BLOCK_K=BLOCK_K, + num_warps=2, + num_stages=3, + ) + return result + + +@triton.jit +def reduce_kernel( + lse_ptr, # float32 [H, N] + m_ij_ptr, # float32 [H, B, N] + l_ij_first_ptr, # float32 [H, 1, N] + l_ij_rest_ptr, # float32 [H, B, N] + m_ij_last_ptr, # float32 [H, N] + o_ptr, # o: n x h x d + o_tiles_first_ptr, # o_tiles: n x h x 1 x d + o_tiles_rest_ptr, # o_tiles: n x h x b x d + acc_o_scales_first_ptr, # acc_o_scales: n x h x 1 + acc_o_scales_rest_ptr, # acc_o_scales: n x h x b + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + start_head_id, + num_qz_loop, + TOPK, + total_len, + # stride + stride_lse_h, + stride_lse_n, + stride_m_ij_h, + stride_m_ij_b, + stride_m_ij_n, + stride_l_ij_fh, + stride_l_ij_fb, + stride_l_ij_fn, + stride_l_ij_rh, + stride_l_ij_rb, + stride_l_ij_rn, + stride_on, + stride_oh, + stride_od, + stride_otfh, + stride_otfb, + stride_otfn, + stride_otfd, + stride_otrh, + stride_otrb, + stride_otrn, + stride_otrd, + stride_acc_fh, + stride_acc_fb, + stride_acc_fn, + stride_acc_rh, + stride_acc_rb, + stride_acc_rn, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + o_ptrs = o_ptr + pid_q_j * stride_on + off_d + last_acc_o = tl.load(o_ptrs, mask=off_d < BLOCK_SIZE_D, other=0.0) + acc_o = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + acc_o += last_acc_o + + lse_ptrs = lse_ptr + pid_q_j * stride_lse_n + # Load lse + lse = tl.load(lse_ptrs, mask=pid_q_j < total_len, other=float("-inf")) + + # the stride is 1 for m_ij_last + m_ij_last = tl.load(m_ij_last_ptr + pid_q_j) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + real_block_pos = 0 + l_ij_ptr = l_ij_first_ptr + o_tiles_ptr = o_tiles_first_ptr + acc_o_scales_ptr = acc_o_scales_first_ptr + stride_l_ij_b = stride_l_ij_fb + stride_l_ij_n = stride_l_ij_fn + stride_acc_b = stride_acc_fb + stride_acc_n = stride_acc_fn + stride_otb = stride_otfb + stride_otn = stride_otfn + else: + real_block_pos = t - 1 + l_ij_ptr = l_ij_rest_ptr + o_tiles_ptr = o_tiles_rest_ptr + acc_o_scales_ptr = acc_o_scales_rest_ptr + stride_l_ij_b = stride_l_ij_rb + stride_l_ij_n = stride_l_ij_rn + stride_acc_b = stride_acc_rb + stride_acc_n = stride_acc_rn + stride_otb = stride_otrb + stride_otn = stride_otrn + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + m_ij = tl.load( + m_ij_ptr + t * stride_m_ij_b + pid_q_j * stride_m_ij_n, mask=pid_q_j < total_len, other=float("-inf") + ) + l_ij = tl.load( + l_ij_ptr + real_block_pos * stride_l_ij_b + real_token_index * stride_l_ij_n, + mask=real_token_index < total_len, + other=0.0, + ) + delta = lse - m_ij + + log_delta = tl.exp2(delta) + l_ij + + # Update lse + lse = m_ij + tl.log2(log_delta) + + o_tiles_ptrs = ( + o_tiles_ptr + real_block_pos.to(tl.int64) * stride_otb + (real_token_index) * stride_otn + off_d + ) + acc_o_scales_ptrs = acc_o_scales_ptr + real_block_pos * stride_acc_b + (real_token_index) * stride_acc_n + + o_tiles = tl.load(o_tiles_ptrs) + acc_o_scales_tiles = tl.load(acc_o_scales_ptrs) + acc_o = o_tiles + acc_o * acc_o_scales_tiles + + # final scale + acc_o = acc_o * tl.exp2(m_ij_last - lse) + tl.store(o_ptrs, acc_o, mask=off_d < BLOCK_SIZE_D) + + # Store back + tl.store( + lse_ptrs, + lse, + mask=pid_q_j < total_len, + ) + + +@triton.jit +def qk_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + m_i_tiles_ptr, # m_i: h x b x n + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # valid_start_indices: (h x b), + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + num_b_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_m_i_tiles_h, + stride_m_i_tiles_b, + stride_m_i_tiles_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_block_grid = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + + # get q k start and len after rmpad + k_len = tl.load(cu_seqlens_k + 1) + k_ptrs = tl.make_block_ptr( + base=k_ptr + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + for bb in range(num_b_blocks): + pid_block = bb + pid_block_grid * num_b_blocks + + start_id = tl.load(valid_start_indices_ptr + head_id * num_blocks + pid_block) + valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block) + if pid_q * BLOCK_SIZE_Q < valid_tokens: + + c = pid_block * BLOCK_SIZE_K + + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_q_blocks): + pid_q_j = pid_q * num_q_blocks + j + # Enable early return + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + q_ptrs = q_ptr + head_id * stride_qh + q_ptrs_off + # load q + q_mask = (st != -1)[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + + m_i = tl.max(qk, axis=1) + + m_i_tiles_ptrs = ( + m_i_tiles_ptr + + head_id * stride_m_i_tiles_h + + pid_block * stride_m_i_tiles_b + + st * stride_m_i_tiles_n + ) + tl.store(m_i_tiles_ptrs, m_i, mask=(st != -1)) + + +@triton.jit +def forward_kernel_opt( + q_ptr, + k_ptr, + v_ptr, # V: n x h x d + o_tiles_ptr, # O: n x h x b x d + acc_o_scales_ptr, # acc_o_scales: h x b x n + m_ij_tiles_ptr, + l_ij_ptr, # h x b x n + token_index_mapping_ptr, + selected_tokens_ptr, # selected_tokens: sum(valid_lens), + valid_lens_ptr, # valid_lens: (h x b), + valid_start_indices_ptr, # valid_start_indices: (h x b), + min_block_id, + cur_max_valid_tokens, + num_heads, + num_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + HEAD_DIM, + # sm_scale + sm_scale, + num_q_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_oth, + stride_otb, + stride_otn, + stride_otd, + stride_acc_oh, + stride_acc_ob, + stride_acc_on, + stride_m_ij_tiles_h, + stride_m_ij_tiles_b, + stride_m_ij_tiles_n, + stride_l_ij_h, + stride_l_ij_b, + stride_l_ij_n, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + # get batch id and head id + pid_block = tl.program_id(0) // num_heads # block id + head_id = tl.program_id(0) % num_heads + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + head_id * num_blocks + pid_block) + valid_tokens = tl.load(valid_lens_ptr + head_id * num_blocks + pid_block) + if num_q_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (min_block_id + pid_block) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + head_id * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + head_id * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_q_blocks): + pid_q_j = pid_q * num_q_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + + # load m_i + mask = st != -1 + + m_ij_tiles_ptrs = ( + m_ij_tiles_ptr + + head_id * stride_m_ij_tiles_h + + (q_start + st) * stride_m_ij_tiles_n + + (pid_block + min_block_id) * stride_m_ij_tiles_b + ) + m_ij = tl.load(m_ij_tiles_ptrs, mask=mask, other=float("-inf")) + + m_ij_tiles_prev_ptrs = ( + m_ij_tiles_ptr + + head_id * stride_m_ij_tiles_h + + (q_start + st) * stride_m_ij_tiles_n + + (pid_block + min_block_id - 1) * stride_m_ij_tiles_b + ) + m_ij_prev = tl.load(m_ij_tiles_prev_ptrs, mask=mask & (pid_block + min_block_id > 0), other=float("-inf")) + + m_i_minus_m_ij = m_ij_prev - m_ij + + q_ptrs = q_ptr + q_start * stride_qn + head_id * stride_qh + q_ptrs_off + # load q + q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + + # [BLOCK_SIZE_Q, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_Q, BLOCK_SIZE_K] + qk_scale = sm_scale * 1.44269504 + qk += tl.dot(q, k) * qk_scale + + # init statistics + acc_o_buffer = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + + # load m_ij and compute l_ij + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + + # load token index mapping + token_index_mapping_ptrs = ( + token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + min_block_id) * stride_tim_b + ) + token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1) + + l_ij_ptrs = ( + l_ij_ptr + + head_id * stride_l_ij_h + + (q_start + token_index_mapping) * stride_l_ij_n + + (pid_block) * stride_l_ij_b + ) + tl.store(l_ij_ptrs, l_ij, mask=mask) + # scale acc_o + if pid_block + min_block_id == 0: + acc_o_scale = tl.full((BLOCK_SIZE_Q,), 1.0, dtype=tl.float32) + else: + acc_o_scale = tl.exp2(m_i_minus_m_ij) + + tl.store( + acc_o_scales_ptr + + head_id * stride_acc_oh + + (pid_block) * stride_acc_ob + + (q_start + token_index_mapping) * stride_acc_on, + acc_o_scale, + mask=(st != -1), + ) + + p = p.to(v.dtype) + acc_o_buffer = tl.dot(p, v) + + o_ptrs_off = token_index_mapping[:, None] * stride_otn + off_d[None, :] * stride_otd + o_ptrs = o_tiles_ptr + head_id * stride_oth + o_ptrs_off + (pid_block).to(tl.int64) * stride_otb + tl.store(o_ptrs, acc_o_buffer.to(o_tiles_ptr.dtype.element_ty), mask=q_mask) + + +def _topk_sparse_attention_fwd_opt( + q: torch.Tensor, # [total_len, num_heads, head_dim] + k: torch.Tensor, # [total_len, num_heads, head_dim] + v: torch.Tensor, # [total_len, num_heads, head_dim] + topk_idx: torch.Tensor, # [num_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + causal=True, +): + """ + TODO: Currently sequence packing is explicitly done in for loop, will merge in kernels. + """ + o = torch.empty_like(q) + total_len, num_heads, _ = q.shape + lse = torch.empty((num_heads, total_len), dtype=torch.float32, device=q.device) + + permute_results = [] + for i in range(len(cu_seqlens_q) - 1): + cu_seqlens_q_ = cu_seqlens_q[i: i + 2] - cu_seqlens_q[i] + cu_seqlens_k_ = cu_seqlens_k[i: i + 2] - cu_seqlens_k[i] + max_seqlen_q_ = cu_seqlens_q_[1] - cu_seqlens_q_[0] + max_seqlen_k_ = cu_seqlens_k_[1] - cu_seqlens_k_[0] + + q_ = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_ = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + v_ = v[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + topk_idx_ = topk_idx[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + o_seq, lse_seq, permute_results_seq = _topk_sparse_attention_fwd_opt_per_seq( + q_, + k_, + v_, + topk_idx_, + block_size, + cu_seqlens_q_, + cu_seqlens_k_, + max_seqlen_q_, + max_seqlen_k_, + sm_scale, + causal, + ) + o[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = o_seq + + lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = lse_seq + permute_results.append(permute_results_seq) + + return o, lse, permute_results + + +@triton.jit +def index_mapping_kernel( + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + stride_im_h, + stride_im_b, + stride_im_n, + BLOCK_SIZE_K: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_n = tl.program_id(1) + + offs_q = tl.arange(0, BLOCK_SIZE_K) + offs_n = pid_n * BLOCK_SIZE_K + offs_q + + start_id = tl.load(valid_start_indices_ptr + pid_b) + valid_tokens = tl.load(valid_lens_ptr + pid_b) + + st_offs = start_id + offs_n + # st should be in shape [BLOCK_SIZE_K] + st_mask = offs_n < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + + token_im_ptrs = token_index_mapping_ptr + pid_b * stride_im_b + st * stride_im_n + + tl.store(token_im_ptrs, offs_n, mask=st_mask) + + +def index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks): + max_tokens = valid_lens.max() + BLOCK_SIZE_K = 1024 + grid = (num_blocks, triton.cdiv(max_tokens, BLOCK_SIZE_K)) + + index_mapping_kernel[grid]( + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_K, + num_warps=2, + num_stages=3, + ) + + +def online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, +): + + # launch kernel + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_q_blocks = 8 + num_b_blocks = 1 + grid_qk = lambda META: ( + triton.cdiv(num_blocks, num_b_blocks), + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_q_blocks), + ) + qk_kernel[grid_qk]( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + head_tile, + num_blocks, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + num_q_blocks, + num_b_blocks, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + m_i_cur_tiles.stride(0), + m_i_cur_tiles.stride(1), + m_i_cur_tiles.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) + + m_ij_tiles = m_i_cur_tiles.cummax(dim=1).values + m_ij_last = m_ij_tiles[:, -1] + + return m_ij_tiles, m_ij_last + + +def qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, +): + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + # a heuristic that avoids large grid size, and redudant KV loading + num_q_blocks = 8 + + grid_fwd = lambda META: ( + compute_tile_size * head_tile, + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_q_blocks), + ) + + forward_kernel_opt[grid_fwd]( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + num_q_blocks, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + v_tile.stride(0), + v_tile.stride(1), + v_tile.stride(2), + o_tiles.stride(0), + o_tiles.stride(1), + o_tiles.stride(2), + o_tiles.stride(3), + acc_o_scales.stride(0), + acc_o_scales.stride(1), + acc_o_scales.stride(2), + m_ij_tiles.stride(0), + m_ij_tiles.stride(1), + m_ij_tiles.stride(2), + l_ij.stride(0), + l_ij.stride(1), + l_ij.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_stages=3, + num_warps=4, + ) + + +def reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, +): + + num_qy_loop = 4 + num_qz_loop = total_len // num_qy_loop + + grid_reduce = lambda META: ( + num_qy_loop + (total_len % num_qy_loop != 0), + num_qz_loop, + ) + + reduce_kernel[grid_reduce]( + lse, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + o, + o_tiles_first, + o_tiles_rest, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h * head_tile, + num_qz_loop, + TOPK, + total_len, + lse.stride(0), + lse.stride(1), + m_ij_tiles.stride(0), + m_ij_tiles.stride(1), + m_ij_tiles.stride(2), + l_ij_first.stride(0), + l_ij_first.stride(1), + l_ij_first.stride(2), + l_ij_rest.stride(0), + l_ij_rest.stride(1), + l_ij_rest.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + o_tiles_first.stride(0), + o_tiles_first.stride(1), + o_tiles_first.stride(2), + o_tiles_first.stride(3), + o_tiles_rest.stride(0), + o_tiles_rest.stride(1), + o_tiles_rest.stride(2), + o_tiles_rest.stride(3), + acc_o_scales_first.stride(0), + acc_o_scales_first.stride(1), + acc_o_scales_first.stride(2), + acc_o_scales_rest.stride(0), + acc_o_scales_rest.stride(1), + acc_o_scales_rest.stride(2), + topk_idx_tile.stride(0), + topk_idx_tile.stride(1), + topk_idx_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_T=triton.next_power_of_2(TOPK), + BLOCK_SIZE_D=triton.next_power_of_2(head_dim), + num_warps=1, + num_stages=2, + ) + + +def _topk_sparse_attention_fwd_opt_per_seq( + q: torch.Tensor, # [total_len, num_heads, head_dim] + k: torch.Tensor, # [total_len, num_kv_heads, head_dim] + v: torch.Tensor, # [total_len, num_kv_heads, head_dim] + topk_idx: torch.Tensor, # [num_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + causal=True, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + + total_len, num_heads, head_dim = q.shape + total_len, num_kv_heads, head_dim = k.shape + + assert num_heads % num_kv_heads == 0 + gqa_deg = num_heads // num_kv_heads + + TOPK = topk_idx.shape[-1] + + real_num_blocks = math.ceil(total_len / block_size) + num_blocks = max(real_num_blocks, TOPK) + + head_tile = 1 + reduce_tile_size = num_blocks - 1 + + valid_lens_all = torch.zeros( + ( + num_kv_heads, + num_blocks, + ), + dtype=torch.int32, + device=q.device, + ) + for h in range(num_kv_heads): + topk_idx_tile = topk_idx[h * head_tile: (h + 1) * head_tile] + topk_idx_nonneg = topk_idx_tile[topk_idx_tile >= 0] + valid_lens = torch.bincount(topk_idx_nonneg.view(-1), minlength=num_blocks) + valid_lens_all[h * head_tile: (h + 1) * head_tile] = valid_lens + + global_max_valid_tokens = valid_lens_all[:, 1:].max() if num_blocks > 1 else valid_lens_all.max() + + o_full = torch.zeros_like(q) + lse_full = torch.full((num_heads, total_len), float("-inf"), dtype=torch.float32, device=q.device) + + # New introduced buffers + topk_idx_permuted_tile = torch.full((head_tile, num_blocks, total_len), -1, dtype=torch.int32, device=q.device) + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, dtype=torch.int32, device=q.device) + # first KV block is computed seaprately + o_tiles_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=q.device) + o_tiles_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=q.device + ) + + # Statistics buffers + # m_i_tiles: 历史最大, m_diff_tiles: 历史最大和当前最大的差值 + # m_i_cur_tiles: 当前最大, # m_ij_tiles: 考虑当前和历史后的最大 + m_i_cur_tiles: torch.Tensor = torch.full( + (head_tile, num_blocks, total_len), float("-inf"), dtype=torch.float32, device=q.device + ) + + # first KV block is reduced separately + l_ij_first = torch.full((head_tile, 1, total_len), 0, dtype=torch.float32, device=q.device) + acc_o_scales_first = torch.full((head_tile, 1, total_len), 1, dtype=torch.float32, device=q.device) + + l_ij_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 0, dtype=torch.float32, device=q.device + ) + acc_o_scales_rest = torch.full( + (head_tile, reduce_tile_size, global_max_valid_tokens), 1, dtype=torch.float32, device=q.device + ) + + permute_results = {} + permute_results['global_max_valid_tokens'] = global_max_valid_tokens + permute_results['num_blocks'] = num_blocks + permute_results['real_num_blocks'] = real_num_blocks + permute_results['valid_topk_idx_permuted_tile'] = [] + permute_results['valid_lens_all'] = valid_lens_all + permute_results['valid_lens'] = [] + permute_results['valid_start_indices'] = [] + + for h in range(num_heads // head_tile): + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + v_tile = v[:, (h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + o = o_full[:, h * head_tile: (h + 1) * head_tile] + lse = lse_full[h * head_tile: (h + 1) * head_tile] + + permute_min_block_id = 0 + permute_max_block_id = min(permute_min_block_id + num_blocks, num_blocks) + + topk_idx_tile = topk_idx[(h // gqa_deg) * head_tile: ((h // gqa_deg + 1)) * head_tile] + + if h % gqa_deg == 0: + topk_idx_permuted_tile = build_block_to_token_triton( + topk_idx_permuted_tile, topk_idx_tile, permute_min_block_id, permute_max_block_id, padding_value=-1 + ) + + valid_topk_idx_permuted_tile = topk_idx_permuted_tile[topk_idx_permuted_tile != -1] + valid_lens = valid_lens_all[(h // gqa_deg) * head_tile, :] + valid_start_indices = torch.nn.functional.pad(valid_lens.cumsum(0)[:-1], (1, 0), value=0) + + index_mapping( + token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks + ) + + permute_results['valid_topk_idx_permuted_tile'].append(valid_topk_idx_permuted_tile) + permute_results['valid_lens'].append(valid_lens) + permute_results['valid_start_indices'].append(valid_start_indices) + + m_ij_tiles, m_ij_last = online_softmax( + q_tile, + k_tile, + m_i_cur_tiles, + valid_topk_idx_permuted_tile, + valid_lens, + valid_start_indices, + 0, + total_len, + block_size, + num_blocks, + head_tile, + head_dim, + sm_scale, + cu_seqlens_q, + cu_seqlens_k, + ) + + m_ij_tiles[:, :, :] = m_ij_tiles[:, :, 0][:, :, None] + m_ij_last[:, :] = m_ij_last[:, 0] + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + o_tiles = o_tiles_first + l_ij = l_ij_first + acc_o_scales = acc_o_scales_first + compute_tile_size = 1 + else: + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + o_tiles = o_tiles_rest + l_ij = l_ij_rest + acc_o_scales = acc_o_scales_rest + compute_tile_size = num_blocks - 1 + + # launch kernel + qkv_kernel( + q_tile, + k_tile, + v_tile, + o_tiles, + acc_o_scales, + m_ij_tiles, + l_ij, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + compute_min_block_id, + cur_max_valid_tokens, + head_tile, + compute_tile_size, + cu_seqlens_q, + cu_seqlens_k, + head_dim, + sm_scale, + block_size, + ) + + reduce_output( + lse, + o, + o_tiles_first, + o_tiles_rest, + m_ij_tiles, + l_ij_first, + l_ij_rest, + m_ij_last, + acc_o_scales_first, + acc_o_scales_rest, + topk_idx_tile, + token_index_mapping, + h, + head_tile, + total_len, + TOPK, + head_dim, + ) + + o_full[:, h * head_tile: (h + 1) * head_tile] = o + lse_full[h * head_tile: (h + 1) * head_tile] = lse + + if h % gqa_deg == 0: + fused_fill(topk_idx_permuted_tile, m_i_cur_tiles) + + return o_full, lse_full, permute_results + + +@triton.jit +def dq_compute_kernel( + q_ptr, + k_ptr, + v_ptr, + lse_ptr, + delta_ptr, + do_ptr, + dq_tiles_ptr, + token_index_mapping_ptr, + selected_tokens_ptr, + valid_lens_ptr, + valid_start_indices_ptr, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + HEAD_DIM, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug_ptr, + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tim_h, + stride_tim_b, + stride_tim_n, + stride_dqth, + stride_dqtb, + stride_dqtn, + stride_dqtd, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + + pid_block = tl.program_id(0) + pid_q = tl.program_id(1) # token + # seq packing is not supported yet + q_start = 0 + k_start = 0 + + k_len = tl.load(cu_seqlens_k + 1) - k_start + + start_id = tl.load(valid_start_indices_ptr + pid_block) + valid_tokens = tl.load(valid_lens_ptr + pid_block) + if num_dq_blocks * pid_q * BLOCK_SIZE_Q >= valid_tokens: + return + + c = (pid_block + compute_min_block_id) * BLOCK_SIZE_K + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + + # load k + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + + # load v + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + + qk_scale = sm_scale * 1.44269504 + + off_k = tl.arange(0, BLOCK_SIZE_K) + off_d = tl.arange(0, BLOCK_SIZE_D) + for j in range(num_dq_blocks): + pid_q_j = pid_q * num_dq_blocks + j + if pid_q_j * BLOCK_SIZE_Q < valid_tokens: + # one thread block for one KV block, a subset of selected tokens + st_offs = start_id + (q_start + pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) + # st should be in shape [BLOCK_SIZE_Q] + st_mask = (pid_q_j * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q)) < valid_tokens + + st = tl.load(selected_tokens_ptr + st_offs, mask=st_mask, other=-1) + tl.store(debug_ptr + tl.arange(0, BLOCK_SIZE_Q), st_offs) + # otherwise, st selects a set of q tokens, selected_tokens_ptr should be sorted + q_ptrs_off = st[:, None] * stride_qn + off_d[None, :] * stride_qd + + mask = st != -1 + + q_ptrs = q_ptr + q_start * stride_qn + q_ptrs_off + # load q + q_mask = mask[:, None] & (off_d < HEAD_DIM)[None, :] + q = tl.load(q_ptrs, mask=q_mask, other=0) + do_ptrs = do_ptr + q_start * stride_qn + q_ptrs_off + do = tl.load(do_ptrs, mask=q_mask, other=0) + delta_ptrs = delta_ptr + st[:, None] + d = tl.load(delta_ptrs, mask=mask[:, None], other=0) + lse_ptrs = lse_ptr + st[:, None] + lse = tl.load(lse_ptrs, mask=mask[:, None], other=0) + + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((st[:, None] >= c + off_k[None, :]), 0, float("-inf")) + qk += tl.dot(q, tl.trans(k)) * qk_scale # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + p = tl.exp2(qk - lse) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + dp = tl.dot(do, v) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + ds = sm_scale * p * (dp - d) # [BLOCK_SIZE_Q, BLOCK_SIZE_K] + ds = ds.to(q.dtype) + dq = tl.dot(ds, k) # [BLOCK_SIZE_Q, BLOCK_SIZE_D] + + # load token index mapping + token_index_mapping_ptrs = ( + token_index_mapping_ptr + (st) * stride_tim_n + (pid_block + compute_min_block_id) * stride_tim_b + ) + token_index_mapping = tl.load(token_index_mapping_ptrs, mask=mask, other=-1) + + dq_ptrs_off = token_index_mapping[:, None] * stride_dqtn + off_d[None, :] * stride_dqtd + dq_tiles_ptrs = dq_tiles_ptr + dq_ptrs_off + (pid_block).to(tl.int64) * stride_dqtb + tl.store(dq_tiles_ptrs, dq.to(dq_tiles_ptr.dtype.element_ty), mask=q_mask) + + +@triton.jit +def dq_reduce_kernel( + dq_buffer_first_ptr, # [H, 1, N, D] + dq_buffer_rest_ptr, # [H, B, N, D] + dq_ptr, # o: n x h x d + t_ptr, # topk_idx: h x n x k + token_index_mapping_ptr, + num_qz_loop, + TOPK, + total_len, + # stride + stride_dqtfh, + stride_dqtfb, + stride_dqtfn, + stride_dqtfd, + stride_dqtrh, + stride_dqtrb, + stride_dqtrn, + stride_dqtrd, + stride_dqn, + stride_dqh, + stride_dqd, + stride_th, + stride_tn, + stride_tk, + stride_tim_h, + stride_tim_b, + stride_tim_n, + # META parameters + BLOCK_SIZE_T: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_qy = tl.program_id(0) + pid_q = tl.program_id(1) # token + + pid_q_j = pid_q + pid_qy * num_qz_loop + if pid_q_j < total_len: + t_ptr_j = t_ptr + pid_q_j * stride_tn + + off_d = tl.arange(0, BLOCK_SIZE_D) + dq_ptrs = dq_ptr + pid_q_j * stride_dqn + off_d + acc_dq = tl.zeros((BLOCK_SIZE_D,), dtype=tl.float32) + + for block_id in range(TOPK): + t = tl.load(t_ptr_j + block_id * stride_tk, mask=block_id < TOPK, other=-1) + if t != -1: + if t == 0: + dq_buffer_ptr = dq_buffer_first_ptr + stride_dqtb = stride_dqtfb + stride_dqtn = stride_dqtfn + real_block_pos = 0 + else: + dq_buffer_ptr = dq_buffer_rest_ptr + stride_dqtb = stride_dqtrb + stride_dqtn = stride_dqtrn + real_block_pos = t - 1 + + # init pointers + token_index_mapping_ptrs = ( + token_index_mapping_ptr + t.to(tl.int64) * stride_tim_b + (pid_q_j) * stride_tim_n + ) + real_token_index = tl.load(token_index_mapping_ptrs) + + dq_buffer_ptrs = ( + dq_buffer_ptr + real_block_pos.to(tl.int64) * stride_dqtb + (real_token_index) * stride_dqtn + off_d + ) + + dq_buffers = tl.load(dq_buffer_ptrs) + acc_dq = dq_buffers + acc_dq + + tl.store(dq_ptrs, acc_dq, mask=off_d < BLOCK_SIZE_D) + + +def backward_dq_opt( + q, # [total_len, num_heads, head_dim] + k, # [total_len, num_k_heads, head_dim] + v, # [total_len, num_k_heads, head_dim] + topk_idx, # [num_k_heads, total_len, topk] + lse, # [num_heads, total_len] + delta, # [num_heads, total_len] + do, # [total_len, num_heads, head_dim] + dq, # [total_len, num_heads, head_dim] + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, +): + """ + TODO: Currently sequence packing is explicitly done in for loop, will merge in kernels. + """ + for i in range(len(cu_seqlens_q) - 1): + cu_seqlens_q_ = cu_seqlens_q[i: i + 2] - cu_seqlens_q[i] + cu_seqlens_k_ = cu_seqlens_k[i: i + 2] - cu_seqlens_k[i] + + permute_results_ = permute_results[i] + + q_ = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_ = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + v_ = v[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + topk_idx_ = topk_idx[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + lse_ = lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + delta_ = delta[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + do_ = do[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + dq_ = dq[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + + backward_dq_opt_per_seq( + q_, + k_, + v_, + topk_idx_, + lse_, + delta_, + do_, + dq_, + cu_seqlens_q_, + cu_seqlens_k_, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results_, + ) + + dq[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = dq_ + + return dq + + +def backward_dq_opt_per_seq( + q, # [total_len, num_k_heads, head_dim] + k, # [total_len, num_k_heads, head_dim] + v, # [total_len, num_k_heads, head_dim] + topk_idx, # [num_k_heads, total_len, topk] + lse, # [num_k_heads, total_len] + delta, # [num_k_heads, total_len] + do, # [total_len, num_k_heads, head_dim] + dq, # [total_len, num_k_heads, head_dim] + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, +): + head_tile = 1 + total_len = topk_idx.shape[1] + global_max_valid_tokens = permute_results['global_max_valid_tokens'] + num_blocks = permute_results['num_blocks'] + reduce_tile_size = num_blocks - 1 + dq_buffer_first = torch.zeros((head_tile, 1, total_len, head_dim), dtype=torch.bfloat16, device=dq.device) + dq_buffer_rest = torch.zeros( + (head_tile, reduce_tile_size, global_max_valid_tokens, head_dim), dtype=torch.bfloat16, device=dq.device + ) + + num_heads = num_share_q_heads * num_k_heads + + token_index_mapping = torch.full((head_tile, num_blocks, total_len), 0, dtype=torch.int32, device=q.device) + for h in range(num_heads // head_tile): + valid_topk_idx_permuted_tile = permute_results['valid_topk_idx_permuted_tile'][h // num_share_q_heads] + + valid_lens = permute_results['valid_lens'][h // num_share_q_heads] + valid_start_indices = permute_results['valid_start_indices'][h // num_share_q_heads] + + index_mapping(token_index_mapping, valid_topk_idx_permuted_tile, valid_lens, valid_start_indices, num_blocks) + q_tile = q[:, h * head_tile: (h + 1) * head_tile] + k_tile = k[:, (h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + v_tile = v[:, (h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + do_tile = do[:, h * head_tile: (h + 1) * head_tile] + lse_tile = lse[h * head_tile: (h + 1) * head_tile] + topk_idx_tile = topk_idx[(h // num_share_q_heads) * head_tile: ((h // num_share_q_heads + 1)) * head_tile] + delta_tile = delta[h * head_tile: (h + 1) * head_tile] + dq_tile = dq[:, h * head_tile: (h + 1) * head_tile] + + for compute_min_block_id in range(min(2, num_blocks)): + if compute_min_block_id == 0: + compute_tile_size = 1 + cur_max_valid_tokens = total_len + cur_valid_lens = valid_lens[0] + cur_valid_start_indices = valid_start_indices[0] + dq_buffer = dq_buffer_first + else: + compute_tile_size = num_blocks - 1 + cur_max_valid_tokens = valid_lens[compute_min_block_id:].max() + cur_valid_lens = valid_lens[compute_min_block_id:] + cur_valid_start_indices = valid_start_indices[compute_min_block_id:] + dq_buffer = dq_buffer_rest + + BLOCK_SIZE_Q = 128 + num_dq_blocks = 8 + grid_dq = lambda META: ( + compute_tile_size, + triton.cdiv(cur_max_valid_tokens, BLOCK_SIZE_Q * num_dq_blocks), + ) + + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + debug = torch.zeros((BLOCK_SIZE_Q,), dtype=torch.int32, device=dq.device) + dq_compute_kernel[grid_dq]( + q_tile, + k_tile, + v_tile, + lse_tile, + delta_tile, + do_tile, + dq_buffer, + token_index_mapping, + valid_topk_idx_permuted_tile, + cur_valid_lens, + cur_valid_start_indices, + cur_max_valid_tokens, + compute_min_block_id, + head_tile, + num_blocks, + head_dim, + cu_seqlens_k, + num_dq_blocks, + sm_scale, + debug, + q_tile.stride(0), + q_tile.stride(1), + q_tile.stride(2), + k_tile.stride(0), + k_tile.stride(1), + k_tile.stride(2), + v_tile.stride(0), + v_tile.stride(1), + v_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + dq_buffer.stride(0), + dq_buffer.stride(1), + dq_buffer.stride(2), + dq_buffer.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + + num_qy_loop = 4 + num_qz_loop = total_len // num_qy_loop + + grid_reduce = lambda META: ( + num_qy_loop + (total_len % num_qy_loop != 0), + num_qz_loop, + ) + dq_reduce_kernel[grid_reduce]( + dq_buffer_first, + dq_buffer_rest, + dq_tile, + topk_idx_tile, + token_index_mapping, + num_qz_loop, + topk, + total_len, + dq_buffer_first.stride(0), + dq_buffer_first.stride(1), + dq_buffer_first.stride(2), + dq_buffer_first.stride(3), + dq_buffer_rest.stride(0), + dq_buffer_rest.stride(1), + dq_buffer_rest.stride(2), + dq_buffer_rest.stride(3), + dq_tile.stride(0), + dq_tile.stride(1), + dq_tile.stride(2), + topk_idx_tile.stride(0), + topk_idx_tile.stride(1), + topk_idx_tile.stride(2), + token_index_mapping.stride(0), + token_index_mapping.stride(1), + token_index_mapping.stride(2), + BLOCK_SIZE_T=triton.next_power_of_2(topk), + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=1, + num_stages=2, + ) + + dq[:, h * head_tile: (h + 1) * head_tile] = dq_tile + + return dq + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_bwd_opt( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + permute_results, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = torch.cat( + [ + permute_results[i]['valid_lens_all'][:, : permute_results[i]['real_num_blocks']] + for i in range(len(permute_results)) + ], + dim=1, + ) + + cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq_opt( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + block_size, + permute_results, + ) + + return dq, dk, dv + + +class FSATopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + permute_results = None + + o, lse, permute_results = _topk_sparse_attention_fwd_opt( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.permute_results = permute_results + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + permute_results = ctx.permute_results + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd_opt( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + permute_results, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def FSA_topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). + + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return FSATopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def FSA_topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """FSA topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). + + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return FSATopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/attention_backend/fsa/__init__.py b/vortex_torch/attention_backend/fsa/__init__.py new file mode 100644 index 00000000..9efd4740 --- /dev/null +++ b/vortex_torch/attention_backend/fsa/__init__.py @@ -0,0 +1,9 @@ +from .FSA_topk_sparse_attention import ( + FSA_topk_sparse_attention, + FSA_topk_sparse_attention_varlen, +) + +__all__ = [ + "FSA_topk_sparse_attention", + "FSA_topk_sparse_attention_varlen", +] diff --git a/vortex_torch/attention_backend/nsa/__init__.py b/vortex_torch/attention_backend/nsa/__init__.py new file mode 100644 index 00000000..382da01b --- /dev/null +++ b/vortex_torch/attention_backend/nsa/__init__.py @@ -0,0 +1,9 @@ +from .topk_sparse_attention import ( + topk_sparse_attention, + topk_sparse_attention_varlen, +) + +__all__ = [ + "topk_sparse_attention", + "topk_sparse_attention_varlen", +] diff --git a/vortex_torch/attention_backend/nsa/topk_sparse_attention.py b/vortex_torch/attention_backend/nsa/topk_sparse_attention.py new file mode 100644 index 00000000..57a2be70 --- /dev/null +++ b/vortex_torch/attention_backend/nsa/topk_sparse_attention.py @@ -0,0 +1,1280 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +def is_hopper_gpu(): + if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability(0) + major, minor = device_capability + return major == 9 + return False + + +def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): + head_large = head_dim > 64 + block_large = block_size > 64 + if is_hopper_gpu: + if head_large and block_large: + num_warps, num_stages = 8, 3 + elif head_large or block_large: + num_warps, num_stages = 4, 3 + else: + num_warps, num_stages = 2, 2 + else: + if head_large and block_large: + num_warps, num_stages = 8, 3 + elif head_large or block_large: + num_warps, num_stages = 8, 3 + else: + num_warps, num_stages = 2, 2 + return num_warps, num_stages + + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel_orig( + q_ptr, # Q: n x h x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + block_size, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + # q loop num + num_q_loop: tl.constexpr, + num_k_loop: tl.constexpr, + MAX_SEQ_LEN: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid = tl.program_id(0) + + Q = MAX_SEQ_LEN // num_q_loop + HK = NUM_KV_HEADS // num_k_loop + + # 第几个 (b, kh_chunk, q_chunk) + pid_b = pid // (HK * Q) + pid_kh_chunk = (pid % (HK * Q)) // Q # 每个block处理num_k_loop个KV head + pid_q = pid % Q + + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + + if pid_q * num_q_loop >= q_len: + return + real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) + + for kh_offset in range(num_k_loop): + pid_kh = pid_kh_chunk * num_k_loop + kh_offset + pid_h = pid_kh * NUM_SHARE_Q_HEADS + + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + + """Removed causal attention, which should be: + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), + axis=0, + ) + """ + # real_topk = tl.sum( + # tl.where((topk_idx >= 0), 1, 0), + # axis=0, + # ) + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // block_size), 1, 0), + axis=0, + ) + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_h = tl.arange(0, BLOCK_SIZE_H) + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_H,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_H, BLOCK_SIZE_D), 0, dtype=tl.float32) + # sparse attention + for i in range(real_topk): + # get current block start index + c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K + t_ptr_j = t_ptr_j + stride_tk + # load k + k = tl.load(tl.advance(k_ptrs, (0, c)), boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIZE_K] -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(tl.advance(v_ptrs, (c, 0)), boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + (q_start + pid_q_j) * stride_on + pid_h * stride_oh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_oh, stride_od), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + lse_ptrs = lse_ptr + (q_start + pid_q_j) * stride_ln + (pid_h + off_h) * stride_lh + tl.store(lse_ptrs, lse_i, mask=off_h < NUM_SHARE_Q_HEADS) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_o = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_o[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_o[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_o[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_o * stride_dn, delta, mask=off_o < o_len) + + +@triton.jit +def count_kernel( + x_ptr, # [num_kv_heads, total_len, topk] + y_ptr, # [num_kv_heads, total_blocks] + cu_seqlens, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + topk, + stride_xh, + stride_xn, + stride_xk, + stride_yh, + stride_yn, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_R: tl.constexpr, +): + pid_h = tl.program_id(0) + pid_b = tl.program_id(1) + # get start and len after rmpad + seq_start = tl.load(cu_seqlens + pid_b) + seq_len = tl.load(cu_seqlens + pid_b + 1) - seq_start + blocks_start = tl.load(cu_seqblocks + pid_b) + num_blocks = tl.load(cu_seqblocks + pid_b + 1) - blocks_start + # load x + off_k = tl.arange(0, BLOCK_SIZE_K) + off_n = tl.arange(0, BLOCK_SIZE_N) + x_ptr = x_ptr + pid_h * stride_xh + seq_start * stride_xn + x_ptrs = x_ptr + off_n[:, None] * stride_xn + off_k[None, :] * stride_xk + # init y + y = tl.zeros((BLOCK_SIZE_R,), dtype=tl.int32) + # loop + for i in range(0, seq_len, BLOCK_SIZE_N): + x = tl.load( + x_ptrs, + mask=(off_n < seq_len - i)[:, None] & (off_k < topk)[None, :], + other=-1, + ) + x = tl.ravel(x) + y += tl.histogram(x, BLOCK_SIZE_R) + x_ptrs += BLOCK_SIZE_N * stride_xn + # store result + off_r = tl.arange(0, BLOCK_SIZE_R) + y_ptr = y_ptr + pid_h * stride_yh + blocks_start * stride_yn + y_ptrs = y_ptr + off_r * stride_yn + tl.store(y_ptrs, y.to(y_ptr.dtype.element_ty), mask=off_r < num_blocks) + + +def count_query( + topk_idx: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + seqblocks = cu_seqblocks[1:] - cu_seqblocks[:-1] + batch_size = seqlens.shape[0] + BLOCK_SIZE_K = triton.next_power_of_2(topk) + BLOCK_SIZE_N = triton.next_power_of_2(4096 // BLOCK_SIZE_K) + BLOCK_SIZE_R = triton.next_power_of_2(seqblocks.max().item() + 2) + active_query_count = torch.zeros(num_kv_heads, cu_seqblocks[-1], dtype=torch.int32, device=topk_idx.device) + grid = (num_kv_heads, batch_size) + count_kernel[grid]( + topk_idx, + active_query_count, + cu_seqlens, + cu_seqblocks, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + active_query_count.stride(0), + active_query_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_R=BLOCK_SIZE_R, + num_warps=4, + num_stages=3, + ) + return active_query_count + + +@triton.jit +def pad_topk_idx_kernel( + t_ptr, + p_ptr, + cu_seqlens, + topk, + stride_th, + stride_tn, + stride_tk, + stride_pb, + stride_ph, + stride_pn, + stride_pk, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_start = tl.load(cu_seqlens + pid_b) + q_len = tl.load(cu_seqlens + pid_b + 1) - q_start + if BLOCK_SIZE_N * pid_n >= q_len: + return + # init prts + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + q_start * stride_tn, + shape=(q_len, topk), + strides=(stride_tn, stride_tk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + p_ptrs = tl.make_block_ptr( + base=p_ptr + pid_b * stride_pb + pid_h * stride_ph, + shape=(q_len, topk), + strides=(stride_pn, stride_pk), + offsets=(pid_n * BLOCK_SIZE_N, 0), + block_shape=(BLOCK_SIZE_N, BLOCK_SIZE_T), + order=(1, 0), + ) + # load and save + idxs = tl.load(t_ptrs, boundary_check=(0, 1)) + tl.store(p_ptrs, idxs, boundary_check=(0, 1)) + + +@triton.jit +def save_topk_idx_kernel( + p_ptr, + t_ptr, + cu_seqblocks, + cu_topk_q_count, + n_len, + stride_pb, + stride_ph, + stride_pn, + stride_th, + stride_tn, + stride_ch, + stride_cn, + BLOCK_SIZE_N: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_n = tl.program_id(2) + # get q start and len after rmpad + q_block_start = tl.load(cu_seqblocks + pid_b) + q_block_end = tl.load(cu_seqblocks + pid_b + 1) + c_start = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_start * stride_cn) + c_end = tl.load(cu_topk_q_count + pid_h * stride_ch + q_block_end * stride_cn) + c_len = c_end - c_start + if c_len <= 0: + return + if pid_n * BLOCK_SIZE_N >= c_len: + return + # init ptrs + p_ptrs = tl.make_block_ptr( + base=p_ptr + pid_b * stride_pb + pid_h * stride_ph + (n_len - c_len) * stride_pn, + shape=(c_len,), + strides=(stride_pn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + t_ptrs = tl.make_block_ptr( + base=t_ptr + pid_h * stride_th + c_start * stride_tn, + shape=(c_len,), + strides=(stride_tn,), + offsets=(pid_n * BLOCK_SIZE_N,), + block_shape=(BLOCK_SIZE_N,), + order=(0,), + ) + # load and save + idxs = tl.load(p_ptrs, boundary_check=(0,)) + tl.store(t_ptrs, idxs, boundary_check=(0,)) + + +def reorder_topk_idx( + topk_idx: torch.Tensor, + cu_topk_q_count: torch.Tensor, + cu_seqlens: torch.Tensor, + cu_seqblocks: torch.Tensor, + block_size: int, +): + num_kv_heads, total_len, topk = topk_idx.shape + batch_size = cu_seqlens.shape[0] - 1 + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + # pad shape [num_kv_heads, total_seqlen, topk] to [batch_size, num_kv_heads, max_seqlen, topk] + pad_topk_idx = torch.full( + (batch_size, num_kv_heads, max_seqlen, topk), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + BLOCK_SIZE_N = min(triton.next_power_of_2(max_seqlen), triton.next_power_of_2(8192 // BLOCK_SIZE_T)) + grid = (batch_size, num_kv_heads, triton.cdiv(max_seqlen, BLOCK_SIZE_N)) + pad_topk_idx_kernel[grid]( + topk_idx, + pad_topk_idx, + cu_seqlens, + topk, + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + pad_topk_idx.stride(0), + pad_topk_idx.stride(1), + pad_topk_idx.stride(2), + pad_topk_idx.stride(3), + BLOCK_SIZE_N=BLOCK_SIZE_N, + BLOCK_SIZE_T=BLOCK_SIZE_T, + ) + # argsort + pad_topk_q_idx = pad_topk_idx.view(batch_size, num_kv_heads, -1).argsort(-1) // topk + pad_topk_q_idx = pad_topk_q_idx.to(torch.int32) + # save as remove pad version + topk_q_idx = torch.full( + (num_kv_heads, cu_topk_q_count[:, -1].max().item()), + fill_value=-1, + device=topk_idx.device, + dtype=torch.int32, + ) + max_len = (cu_topk_q_count[:, cu_seqblocks][:, 1:] - cu_topk_q_count[:, cu_seqblocks][:, :-1]).max().item() + BLOCK_SIZE_N = min(triton.next_power_of_2(max_len), 8192) + grid = (batch_size, num_kv_heads, triton.cdiv(max_len, BLOCK_SIZE_N)) + save_topk_idx_kernel[grid]( + pad_topk_q_idx, + topk_q_idx, + cu_seqblocks, + cu_topk_q_count, + pad_topk_q_idx.shape[-1], + pad_topk_q_idx.stride(0), + pad_topk_q_idx.stride(1), + pad_topk_q_idx.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + BLOCK_SIZE_N=BLOCK_SIZE_N, + ) + return topk_q_idx + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + tq_ptr, # topk_q_idx: kh x N + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DK: sh x n x kh x d + # seqlens + cu_seqlens_q, # [batch_size + 1] + cu_seqlens_k, # [batch_size + 1] + cu_seqblocks, # [batch_size + 1] + cu_topk_q_count, # [kh, total_blocks] + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_tqh, + stride_tqn, + stride_ctqh, + stride_ctqn, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # get topk_q_idx + b_start = tl.load(cu_seqblocks + pid_b) # how many blocks before current sequence + act_q_start = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k) * stride_ctqn) + act_q_end = tl.load(cu_topk_q_count + pid_kh * stride_ctqh + (b_start + pid_k + 1) * stride_ctqn) + act_q_len = act_q_end - act_q_start + tq_ptr = tq_ptr + pid_kh * stride_tqh + act_q_start * stride_tqn + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + off_d = tl.arange(0, BLOCK_SIZE_D) + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # init ptrs + q_ptrs = q_ptr + q_start * stride_qn + pid_h * stride_qh + off_d[None, :] * stride_qd + do_ptrs = do_ptr + q_start * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod + d_ptrs = d_ptr + q_start * stride_dn + pid_h * stride_dh + lse_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + # loop for q blocks + for i in range(0, act_q_len, BLOCK_SIZE_Q): + # load + idx_q = tl.load(tq_ptr + i + off_q, mask=off_q < act_q_len - i, other=0).to(tl.int32) + q = tl.load( + q_ptrs + idx_q[:, None] * stride_qn, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + do = tl.load( + do_ptrs + idx_q[:, None] * stride_don, + mask=(off_q < act_q_len - i)[:, None] & (off_d < HEAD_DIM)[None, :], + other=0, + ) + lse = tl.load( + lse_ptrs + idx_q[:, None] * stride_ln, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + d = tl.load( + d_ptrs + idx_q[:, None] * stride_dn, + mask=(off_q < act_q_len - i)[:, None], + other=0, + ) + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(idx_q[:, None] >= off_k[None, :], float(0.0), float("-inf")) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + t_ptr, # topk_idx: kh x n x k + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + TOPK, + # q loop num + num_q_loop, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_th, + stride_tn, + stride_tk, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, + BLOCK_SIZE_H: tl.constexpr, + BLOCK_SIZE_T: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_kh = tl.program_id(1) + pid_q = tl.program_id(2) + pid_h = pid_kh * NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * num_q_loop >= q_len: + return + real_q_loop = min(num_q_loop, q_len - pid_q * num_q_loop) + for j in range(real_q_loop): + pid_q_j = pid_q * num_q_loop + j + # init topk idx pointer + off_t = tl.arange(0, BLOCK_SIZE_T) + t_ptr_j = t_ptr + (q_start + pid_q_j) * stride_tn + pid_kh * stride_th + topk_idx = tl.load(t_ptr_j + off_t * stride_tk, mask=off_t < TOPK, other=-1) + + real_topk = tl.sum( + tl.where((topk_idx >= 0) & (topk_idx <= pid_q_j // BLOCK_SIZE_K), 1, 0), + axis=0, + ) + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + (q_start + pid_q_j) * stride_qn + pid_h * stride_qh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_qh, stride_qd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + (q_start + pid_q_j) * stride_dqn + pid_h * stride_dqh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_dqh, stride_dqd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + (q_start + pid_q_j) * stride_don + pid_h * stride_doh, + shape=(NUM_SHARE_Q_HEADS, HEAD_DIM), + strides=(stride_doh, stride_dod), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + (q_start + pid_q_j) * stride_dn + pid_h * stride_dh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_dh, stride_dn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + (q_start + pid_q_j) * stride_ln + pid_h * stride_lh, + shape=(NUM_SHARE_Q_HEADS, 1), + strides=(stride_lh, stride_ln), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_H, 1), + order=(1, 0), + ) + # offsets + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_D), dtype=tl.float32) + # sparse + for i in range(real_topk): + # get current block start index + c = tl.load(t_ptr_j).to(tl.int32) * BLOCK_SIZE_K + t_ptr_j = t_ptr_j + stride_tk + # load + k = tl.load(tl.advance(k_ptrs, (c, 0)), boundary_check=(1, 0), padding_option="zero") + v = tl.load(tl.advance(v_ptrs, (0, c)), boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_H, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where((pid_q_j >= c + off_k)[None, :], 0, float("-inf")) + # [BLOCK_SIZE_H, HEAD_DIM] @ [BLOCK_SIZE_K, HEAD_DIM].T -> [BLOCK_SIZE_H, BLOCK_SIZE_K] + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _topk_sparse_attention_fwd( + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert block_size in {16, 32, 64, 128, 256} + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + topk = topk_idx.shape[-1] + assert topk_idx.shape[0] == num_k_heads + assert topk_idx.shape[1] == q_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + + lse = torch.zeros(num_q_heads, q_len, dtype=torch.float32, device=q.device) + + # launch kernel + num_q_loop = num_k_loop = 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + + def grid(meta): + grid = ( + batch_size * triton.cdiv(num_k_heads, num_k_loop) * triton.cdiv(max_seqlen_q, num_q_loop), + ) + return grid + + num_warps, num_stages = get_num_warps_stages(head_dim, block_size, IS_HOPPER_GPU) + forward_kernel_orig[grid]( + q, + k, + v, + topk_idx, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + block_size, + # num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + num_q_loop=num_q_loop, + num_k_loop=num_k_loop, + MAX_SEQ_LEN=max_seqlen_q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _topk_sparse_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +): + + assert block_size in {16, 32, 64, 128, 256} + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + topk = topk_idx.shape[-1] + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + grid = (triton.cdiv(o_len, BLOCK_SIZE_O), num_o_heads) + + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # count active querys for each key block, shape: (num_k_heads, total_k_blocks) + seqlens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + seqblocks = torch.ceil(seqlens / block_size).to(torch.int32) + cu_seqblocks = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(seqblocks, dim=0), + ] + ).to(torch.int32) + + topk_q_count = count_query(topk_idx, cu_seqlens_q, cu_seqblocks, block_size) + + cu_topk_q_count = torch.cat( + [ + torch.zeros(topk_q_count.shape[0], 1, dtype=torch.int32, device=topk_idx.device), + torch.cumsum(topk_q_count, dim=-1), + ], + dim=-1, + ).to(torch.int32) + # active query idx for each key block + # how to get active query idx for sequence b, head h, kv block i? + topk_q_idx = reorder_topk_idx(topk_idx, cu_topk_q_count, cu_seqlens_q, cu_seqblocks, block_size) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + BLOCK_SIZE_K = triton.next_power_of_2(block_size) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + grid = (batch_size, num_q_heads, triton.cdiv(max_seqlen_k, BLOCK_SIZE_K)) + backward_dkdv[grid]( + q, + k, + v, + topk_q_idx, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + cu_seqblocks, + cu_topk_q_count, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_q_idx.stride(0), + topk_q_idx.stride(1), + cu_topk_q_count.stride(0), + cu_topk_q_count.stride(1), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + num_q_loop = max_seqlen_q // 32768 + 1 # calculate multiple querys in one kernel if seqlence length is too long + grid = (batch_size, num_k_heads, triton.cdiv(max_seqlen_q, num_q_loop)) + BLOCK_SIZE_K = block_size + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_H = max(16, triton.next_power_of_2(num_share_q_heads)) + BLOCK_SIZE_T = triton.next_power_of_2(topk) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + + backward_dq[grid]( + q, + k, + v, + topk_idx, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + topk, + num_q_loop, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + topk_idx.stride(0), + topk_idx.stride(1), + topk_idx.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + BLOCK_SIZE_H=BLOCK_SIZE_H, + BLOCK_SIZE_T=BLOCK_SIZE_T, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class TopkSparseAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, # [total_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_len, num_k_heads, head_dim] + v: torch.Tensor, # [total_len, num_k_heads, head_dim] + topk_idx: torch.Tensor, # [num_kv_heads, total_len, topk] + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert topk_idx.dtype == torch.int32 + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _topk_sparse_attention_fwd( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.block_size = block_size + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k, topk_idx = ctx.saved_tensors + + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + block_size = ctx.block_size + assert block_size in {16, 32, 64, 128, 256} + + dq, dk, dv = _topk_sparse_attention_bwd( + o, + do, + lse, + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None, None + + +def topk_sparse_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention varlen version implemented in triton. + + Args: + q (torch.Tensor): shape [total_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_len, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_len, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens in flash_attn_func_varlen. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). + + Returns: + torch.Tensor: attention output, shape [total_len, num_q_heads, head_dim] + """ + + max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + softmax_scale, + ) + + +def topk_sparse_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + topk_idx: torch.Tensor, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + softmax_scale: Optional[float] = None, +) -> torch.Tensor: + """Topk sparse attention with separate Q and K sequence lengths (for extend/prefill). + + Same as topk_sparse_attention but accepts separate cu_seqlens for Q and K. + Useful when Q only covers new tokens while K covers all tokens (prefix + new). + + Args: + q (torch.Tensor): shape [total_q, num_q_heads, head_dim] + k (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_k, num_kv_heads, head_dim] + topk_idx (torch.Tensor): topk block idx for each query, shape [num_kv_heads, total_q, topk]. -1 means padding. + block_size (int): key value block size. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], cumulative Q sequence lengths. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], cumulative K sequence lengths. + softmax_scale (Optional[float], optional): Defaults to None, means 1/sqrt(head_dim). + + Returns: + torch.Tensor: attention output, shape [total_q, num_q_heads, head_dim] + """ + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + return TopkSparseAttention.apply( + q, + k, + v, + topk_idx, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + softmax_scale, + ) diff --git a/vortex_torch/flow/external_algorithms.py b/vortex_torch/flow/external_algorithms.py new file mode 100644 index 00000000..5f8935fa --- /dev/null +++ b/vortex_torch/flow/external_algorithms.py @@ -0,0 +1,76 @@ +""" +External sparse attention algorithm registrations for NSA, FSA, and FlashMoBA. + +These vFlow subclasses use simple centroid-based routing for the DECODE path +(forward_indexer + forward_cache), identical to BlockSparseAttention. + +The EXTEND path (forward_extend) is handled directly in vtx_graph_backend.py +using each algorithm's own sparse attention kernel — these vFlow classes are +not involved in extend. +""" + +import torch +from typing import Dict, Tuple + +from .flow import vFlow +from ..indexer import topK, GeMV +from ..cache import Mean as CMean +from ..abs import ContextBase +from .registry import register + + +class _ExternalAlgoBase(vFlow): + """ + Base vFlow for external sparse attention algorithms (NSA, FSA, FlashMoBA). + + Decode routing: centroid-based (same as BlockSparseAttention). + Extend: bypassed — vtx_graph_backend dispatches to algorithm-specific kernels. + """ + + def __init__(self): + super().__init__() + self.gemv = GeMV() + self.output_func = topK() + self.reduction = CMean(dim=1) + + def forward_indexer( + self, + q: torch.Tensor, + o: torch.Tensor, + cache: Dict[str, torch.Tensor], + ctx: ContextBase, + ): + q_mean = q.mean(dim=1, keepdim=True) + score = self.gemv(q_mean, cache["centroids"], ctx=ctx) + self.output_func(score, o, ctx=ctx) + + def forward_cache( + self, + cache: Dict[str, torch.Tensor], + loc: torch.Tensor, + ctx: ContextBase, + ): + self.reduction(cache["k"], cache["centroids"], loc=loc, ctx=ctx) + + def create_cache(self, page_size: int, head_dim: int) -> Dict[str, Tuple[int, int]]: + return { + "centroids": (1, head_dim), + } + + +@register("nsa") +class NSASparseAttention(_ExternalAlgoBase): + """Naive Sparse Attention — decode uses centroid routing, extend uses NSA kernels.""" + pass + + +@register("fsa") +class FSASparseAttention(_ExternalAlgoBase): + """Flash Sparse Attention — decode uses centroid routing, extend uses FSA kernels.""" + pass + + +@register("flash_moba") +class FlashMoBASparseAttention(_ExternalAlgoBase): + """FlashMoBA — decode uses centroid routing, extend uses FlashMoBA kernels.""" + pass diff --git a/vortex_torch/kernels/__init__.py b/vortex_torch/kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/vortex_torch/kernels/fsa/__init__.py b/vortex_torch/kernels/fsa/__init__.py new file mode 100644 index 00000000..25d5b3eb --- /dev/null +++ b/vortex_torch/kernels/fsa/__init__.py @@ -0,0 +1,5 @@ +from .fused_score_kernels import _fused_attention_score_and_transform + +__all__ = [ + "_fused_attention_score_and_transform", +] diff --git a/vortex_torch/kernels/fsa/fused_score_kernels.py b/vortex_torch/kernels/fsa/fused_score_kernels.py new file mode 100644 index 00000000..f2a05ed8 --- /dev/null +++ b/vortex_torch/kernels/fsa/fused_score_kernels.py @@ -0,0 +1,300 @@ +# This file provides a fused implementation of computing attention score for selected attention indices. +# TODO: this implementation may incur illegal memory access issues, will be fixed. +import math + +import torch +import triton +import triton.language as tl + +from ..nsa.utils import is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def fused_score_kernel( + q_ptr, # q_len x h x d + k_ptr, # k_len x h x d + lse_ptr, # h x n + bs_ptr, # h x n x nb + offs_ptr, # BO + kernel_size, + kernel_stride, + num_offs, # BO + num_k_blocks, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, # which is also num_q_heads + HEAD_DIM, + # sm_scale + sm_scale, + max_blocks, + pad_len, + block_size, + block_stride, + init_blocks, + local_blocks, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_bsh, + stride_bsq, + stride_bsnb, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) # the blocks id of k + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + + k_start += pid_k * BLOCK_SIZE_K * num_k_blocks + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_kh * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + + for j in range(num_k_blocks): + k_start_j = k_start + j * BLOCK_SIZE_K + if k_start_j < k_len: + off_d = tl.arange(0, BLOCK_SIZE_D) + off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + # k offsets + off_k = (k_start_j + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + k_ptrs = k_ptr + pid_kh * stride_kh + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + + # init block score + bs = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + for i in range(num_offs): + k = tl.load(k_ptrs, mask=causal_mask, other=0) + w = tl.load(offs_ptr + i, mask=i < num_offs, other=0) + # compute qk + qk = tl.dot(q, k) * qk_scale + # compute score and apply weight + bs += w * tl.where(causal_mask, tl.exp2(qk - lse), 0) + + # increment pointers + off_k += 1 + k_ptrs = k_ptr + pid_kh * stride_kh + off_k[None, :] * stride_kn + off_d[:, None] * stride_kd + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + + # init mask and local mask + off_bq = off_q // block_size + off_bk = tl.arange(0, BLOCK_SIZE_K) + bs = tl.where( + ( + (off_bq[:, None] >= k_start_j + off_bk[None, :]) + & (off_bq[:, None] < k_start_j + off_bk[None, :] + local_blocks) + ) + | (off_bk[None, :] < init_blocks - k_start_j), + float("inf"), + bs, + ) + + # save output + bs_ptrs = ( + bs_ptr + + pid_kh.to(tl.int64) * stride_bsh + + q_start * stride_bsq + + k_start_j * stride_bsnb + + off_q[:, None] * stride_bsq + + off_bk[None, :] * stride_bsnb + ) + + tl.store( + bs_ptrs, + bs.to(bs_ptr.dtype.element_ty), + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks - k_start_j)[None, :], + ) + + +def _fused_attention_score_and_transform( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + init_blocks: int = 1, + local_blocks: int = 2, + align_baseline: bool = False, +) -> torch.Tensor: + + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + max_blocks = math.ceil(max_seqlen_q / block_size) + # init block score + block_scores = torch.zeros( + num_k_heads, + q_len, + max_blocks, + dtype=torch.float32 if align_baseline else torch.bfloat16, + device=q.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=q.device)[:, None] + + torch.arange(block_size // kernel_stride, device=q.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + for i in range(cu_seqlens_q.shape[0] - 1): + q_seq = q[cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + k_seq = k[cu_seqlens_k[i]: cu_seqlens_k[i + 1]] + lse_seq = lse[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + block_scores_seq = block_scores[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] + + _fused_attention_score_and_transform_per_seq( + q_seq, + k_seq, + lse_seq, + block_scores_seq, + kernel_size, + kernel_stride, + block_size, + offs, + num_offs, + cu_seqlens_q[i: i + 2] - cu_seqlens_q[i], + cu_seqlens_k[i: i + 2] - cu_seqlens_k[i], + cu_seqlens_q[i + 1] - cu_seqlens_q[i], + cu_seqlens_k[i + 1] - cu_seqlens_k[i], + sm_scale, + init_blocks, + local_blocks, + ) + block_scores[:, cu_seqlens_q[i]: cu_seqlens_q[i + 1]] = block_scores_seq + return block_scores + + +@torch.inference_mode() +def _fused_attention_score_and_transform_per_seq( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + block_score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + offs, + num_offs, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + + max_blocks = math.ceil(max_seqlen_q / block_size) + + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + + BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + # ensure qk is valid on triton + BLOCK_SIZE_K = max(BLOCK_SIZE_K, 16) + BLOCK_SIZE_Q = 128 + + # launch kernel + num_k_blocks = 1 + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K * num_k_blocks), + ) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + fused_score_kernel[grid]( + q, + k, + lse, + block_score, + offs, + kernel_size, + kernel_stride, + num_offs, + num_k_blocks, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + head_dim, + sm_scale, + max_blocks, + pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) diff --git a/vortex_torch/kernels/nsa/__init__.py b/vortex_torch/kernels/nsa/__init__.py new file mode 100644 index 00000000..9af30295 --- /dev/null +++ b/vortex_torch/kernels/nsa/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .compressed_attention import compressed_attention +from .weighted_pool import (avgpool_compress, softmaxpool_compress, + weightedpool_compress) + +__all__ = [ + "compressed_attention", + "avgpool_compress", + "weightedpool_compress", + "softmaxpool_compress", +] diff --git a/vortex_torch/kernels/nsa/compressed_attention.py b/vortex_torch/kernels/nsa/compressed_attention.py new file mode 100644 index 00000000..9770a942 --- /dev/null +++ b/vortex_torch/kernels/nsa/compressed_attention.py @@ -0,0 +1,1317 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import warnings +from typing import Any, Tuple, Union + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # size and stride at compresstion + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # attention + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.exp2(lse_i - m_ij) + l_ij) + # update ptrs + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = pid_k * BLOCK_SIZE_K * kernel_stride + tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + q_lo = pid_k * BLOCK_SIZE_K * kernel_stride + kernel_size - 1 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(HEAD_DIM, q_len), + strides=(stride_qd, stride_qn), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(HEAD_DIM, q_len), + strides=(stride_dod, stride_don), + offsets=(0, q_lo), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_Q), + order=(0, 1), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(1, q_len), + strides=(0, stride_dn), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(1, q_len), + strides=(0, stride_ln), + offsets=(0, q_lo), + block_shape=(1, BLOCK_SIZE_Q), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + qk = tl.where(off_k[:, None] <= (off_q + i)[None, :], float(0.0), float("-inf")) + qk += tl.dot(k, q) * qk_scale + # compute p, ds + # [BLOCK_SIZE_K, BLOCK_SIE_Q] - [1, BLOCK_SIZE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + p = tl.exp2(qk - lse) + # [BLOCK_SIZE_K, HEAD_DIM] @ [HEAD_DIM, BLOCK_SIE_Q] -> [BLOCK_SIZE_K, BLOCK_SIE_Q] + dp = tl.dot(v, do) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + # [BLOCK_SIZE_K, BLOCK_SIE_Q] @ [BLOCK_SIE_Q, HEAD_DIM] -> [BLOCK_SIZE_K, HEAD_DIM] + dk += tl.dot(ds, tl.trans(q)) + dv += tl.dot(p, tl.trans(do)) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (0, BLOCK_SIZE_Q)) + do_ptrs = tl.advance(do_ptrs, (0, BLOCK_SIZE_Q)) + lse_ptrs = tl.advance(lse_ptrs, (0, BLOCK_SIZE_Q)) + d_ptrs = tl.advance(d_ptrs, (0, BLOCK_SIZE_Q)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + # skip first kernel_size query block, because they do no attend to any keys + q_start_in_seq = pid_q * BLOCK_SIZE_Q + kernel_size - 1 + if q_start_in_seq >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(HEAD_DIM, k_len), + strides=(stride_vd, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_start_in_seq, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + q_start_in_seq + off_k = tl.arange(0, BLOCK_SIZE_K) * kernel_stride + kernel_size - 1 + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + lo = 0 + hi = min(k_len, (q_start_in_seq + BLOCK_SIZE_Q - kernel_size) // kernel_stride + 1) + for i in range(lo, hi, BLOCK_SIZE_K): + # load + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.where(off_q[:, None] >= (i * kernel_stride + off_k)[None, :], 0, float("-inf")) + qk += tl.dot(q, tl.trans(k)) * qk_scale + # compute p, ds + p = tl.exp2(qk - lse) + dp = tl.dot(do, v) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # increment pointers + k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) + v_ptrs = tl.advance(v_ptrs, (0, BLOCK_SIZE_K)) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _compressed_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale: float, +): + # dtype check + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert k_len == v_len and q_len > k_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.zeros_like(q) + lse = torch.full( + (num_q_heads, q_len), + fill_value=-torch.inf, + dtype=torch.float32, + device=q.device, + ) + # launch kernel + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + forward_kernel[grid]( + q, + k, + v, + o, + lse, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _compressed_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale: float, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + # compute D + delta = torch.zeros([num_o_heads, o_len], device=o.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.zeros(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.zeros_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + delta, + do, + dq, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class CompressedAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + sm_scale=None, + ): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype and k.dtype == v.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + + o, lse = _compressed_attention_fwd( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.kernel_size = kernel_size + ctx.kernel_stride = kernel_stride + return o, lse + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + kernel_size = ctx.kernel_size + kernel_stride = ctx.kernel_stride + + dq, dk, dv = _compressed_attention_bwd( + o, + do, + lse, + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +@triton.jit +def score_kernel( + q_ptr, + k_ptr, + lse_ptr, + s_ptr, + kernel_size, + kernel_stride, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_lh, + stride_ln, + stride_sh, + stride_sq, + stride_sk, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_bkh = tl.program_id(0) + pid_b = pid_bkh // NUM_KV_HEADS + pid_kh = pid_bkh % NUM_KV_HEADS + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if pid_q * BLOCK_SIZE_Q >= q_len or pid_k * BLOCK_SIZE_K >= k_len: + return + # init k pointer and load k + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + causal_mask = off_q[:, None] >= (off_k * kernel_stride + kernel_size - 1)[None, :] + # init score + s = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_kh * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_kh * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # load q and lse + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k) * qk_scale + # compute score + s += tl.where(causal_mask, tl.exp2(qk - lse), 0) + # save output + s_ptrs = tl.make_block_ptr( + base=s_ptr + pid_kh * stride_sh + q_start * stride_sq, + shape=(q_len, k_len), + strides=(stride_sq, stride_sk), + offsets=(pid_q * BLOCK_SIZE_Q, pid_k * BLOCK_SIZE_K), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_K), + order=(1, 0), + ) + tl.store(s_ptrs, s.to(s_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_attention_score( + q: torch.Tensor, # [total_query_len, num_q_heads, head_dim] + k: torch.Tensor, # [total_key_len, num_k_heads, head_dim] + lse: torch.Tensor, # [num_q_heads, total_query_len] + kernel_size: int, + kernel_stride: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float, +) -> torch.Tensor: + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert q.dtype == k.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + assert lse.dtype == torch.float32 # lse here is log2(sum(exp(qk*scale))), not log(sum(exp(qk*scale))) + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + batch_size = cu_seqlens_q.shape[0] - 1 + assert q_len > k_len + if sm_scale is None: + sm_scale = 1 / math.sqrt(head_dim) + # gqa + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # init score + score = torch.zeros(num_k_heads, q_len, max_seqlen_k, dtype=torch.float32, device=q.device) + + # launch kernel + grid = lambda META: ( + batch_size * num_k_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 128 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + + score_kernel[grid]( + q, + k, + lse, + score, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + lse.stride(0), + lse.stride(1), + score.stride(0), + score.stride(1), + score.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=8, + num_stages=3, + ) + return score + + +@triton.jit +def _transform_score_kernel( + s_ptr, # score, shape: [num_heads, q_len, k_len] + bs_ptr, # block wise score: [num_heads, q_len, num_k_block] + offs, + cu_seqlens_q, + # shape + num_heads, + num_offs, + max_k_len, + max_blocks, + pad_len, + # kernel & block size + block_size, + block_stride, # block_size // kernel_stride + init_blocks, + local_blocks, + # stride + stride_sh, + stride_sq, + stride_sk, + stride_bsh, + stride_bsq, + stride_bsk, + TOTAL_QUERY_LEN: tl.constexpr, + BLOCK_SIZE_Q: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_O: tl.constexpr, +): + pid_bh = tl.program_id(0) + pid_b = pid_bh // num_heads + pid_h = pid_bh % num_heads + pid_q = tl.program_id(1) + pid_k = tl.program_id(2) + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = pid_k * BLOCK_SIZE_K + if pid_q * BLOCK_SIZE_Q >= q_len: + return + # load weight + off_o = tl.arange(0, BLOCK_SIZE_O) + w = tl.load(offs + off_o, mask=off_o < num_offs, other=0) + # load score + off_q = pid_q * BLOCK_SIZE_Q + tl.arange(0, BLOCK_SIZE_Q) + off_k = (k_start + tl.arange(0, BLOCK_SIZE_K)) * block_stride - pad_len + off_k = off_k[None, :] + off_o[:, None] + s_ptrs = ( + s_ptr + + q_start * stride_sq + + pid_h * stride_sh + + off_q[:, None, None] * stride_sq + + off_k[None, :, :] * stride_sk + ) + # weighted sum, [BQ, BO, BK] * [1, BO, 1] -> [BQ, BO, BK] -> [BQ, BK] + s = tl.load( + s_ptrs, + mask=(off_q < q_len)[:, None, None] & (off_k >= 0) & (off_k < max_k_len), + other=0, + ) + s = s * w[None, :, None] + s = tl.sum(s, axis=1) + # init mask and local mask + off_bq = off_q // block_size + off_bk = k_start + tl.arange(0, BLOCK_SIZE_K) + s = tl.where( + ((off_bq[:, None] >= off_bk[None, :]) & (off_bq[:, None] < off_bk[None, :] + local_blocks)) + | (off_bk[None, :] < init_blocks - k_start), + float("inf"), + s, + ) + # store block wise score + bs_ptrs = ( + bs_ptr + q_start * stride_bsq + pid_h * stride_bsh + off_q[:, None] * stride_bsq + off_bk[None, :] * stride_bsk + ) + tl.store( + bs_ptrs, + s, + mask=(off_q < q_len)[:, None] & (off_bk < max_blocks)[None, :], + ) + + +def transform_score( + score: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + init_blocks: int = 1, + local_blocks: int = 2, +) -> torch.Tensor: + num_k_heads, total_query_len, max_key_len = score.shape + batch_size = cu_seqlens_q.shape[0] - 1 + pad_len = kernel_size // kernel_stride - 1 + max_blocks = math.ceil(max_seqlen_q / block_size) + block_score = torch.zeros( + num_k_heads, + total_query_len, + max_blocks, + dtype=torch.float32, + device=score.device, + ) + offs = ( + torch.arange(kernel_size // kernel_stride, device=score.device)[:, None] + + torch.arange(block_size // kernel_stride, device=score.device)[None, :] + ).view(-1) + + offs = torch.histc(offs, bins=offs.max() + 1, min=0, max=offs.max()) + + num_offs = int(offs.shape[0]) + + BLOCK_SIZE_Q = 16 + BLOCK_SIZE_K = min(128, triton.next_power_of_2(max_blocks)) + BLOCK_SIZE_O = triton.next_power_of_2(num_offs) + + def grid(meta): + grid = ( + num_k_heads * batch_size, + triton.cdiv(total_query_len, BLOCK_SIZE_Q), + triton.cdiv(max_blocks, BLOCK_SIZE_K), + ) + return grid + + _transform_score_kernel[grid]( + score, + block_score, + offs, + cu_seqlens_q, + num_k_heads, + offs.shape[0], + max_key_len, + max_blocks, + pad_len, + block_size, + block_size // kernel_stride, + init_blocks, + local_blocks, + score.stride(0), + score.stride(1), + score.stride(2), + block_score.stride(0), + block_score.stride(1), + block_score.stride(2), + TOTAL_QUERY_LEN=total_query_len, + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_O=BLOCK_SIZE_O, + num_warps=4, + num_stages=3, + ) + return block_score + + +def compressed_attention( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + kernel_size: int, + kernel_stride: int, + block_size: int, + topk: int, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sm_scale: float = None, + init_blocks: int = 1, + local_blocks: int = 2, + parallel_topk_compute: Union[str, bool] = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Attention between query and compressed key and value. Compute attention output and topk block idx used in topk_sparse_attention. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_kv_heads, head_dim] + kernel_size (int): kernel size in compress_key_value + kernel_stride (int): stride of compress_key_value + block_size (int): key value block size for topk sparse attention. + topk (int): number of blocks for each query. + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (int): max q len of the batch. + max_seqlen_k (int): max k len of the batch. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + init_blocks (int, optional): Number of init blocks for each query. Defaults to 1. + local_blocks (int, optional): Number of local blocks for each query. Defaults to 2. + parallel_topk_compute (str, optional): Only set it to False when the sequence length is too long. This can avoid a current bug. + We'll fix this issue later. Defaults to auto, it will be set to False when the sequence length is greater than 32k and True otherwise. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: attention output and topk_idx used in topk_sparse_attention + """ + + if max_seqlen_q is None: + max_seqlen_q = (cu_seqlens_q[1:] - cu_seqlens_q[:-1]).max().item() + if max_seqlen_k is None: + max_seqlen_k = (cu_seqlens_k[1:] - cu_seqlens_k[:-1]).max().item() + + attn_output, lse = CompressedAttention.apply( + q, + k, + v, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + + # do not select topk index + if topk <= 0: + warnings.warn("topk <= 0, returned topk_idx will be None") + return attn_output, None + + assert topk >= init_blocks + local_blocks + with torch.no_grad(): + num_k_heads, num_q_heads = k.shape[1], q.shape[1] + num_shared_q_heads = num_q_heads // num_k_heads + batch_size = cu_seqlens_q.shape[0] - 1 + q_idx = torch.cat( + [torch.arange(cu_seqlens_q[i + 1] - cu_seqlens_q[i], device=q.device) for i in range(batch_size)], + dim=0, + ) + q_idx = q_idx // block_size + + # whether to use parallel version + if parallel_topk_compute == "auto": + parallel_topk_compute = cu_seqlens_q[-1] <= 32768 + # parallel version + if parallel_topk_compute: + # recompute score + score = _get_attention_score( + q, + k, + lse, + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + topk_idx = score.topk(topk, dim=-1).indices.sort(-1).values + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + # non parallel version, avoid some current bugs when sequence length is too long + # FIXME: need to fix later + else: + topk_idx_list = [] + head_tile = 1 + assert num_k_heads % head_tile == 0, f"Num kv heads: {num_k_heads}, head_tile: {head_tile}" + for h in range(num_k_heads // head_tile): + # recompute score + score = _get_attention_score( + q[:, h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + k[:, h * head_tile: (h + 1) * head_tile], + lse[h * num_shared_q_heads * head_tile: (h + 1) * num_shared_q_heads * head_tile], + kernel_size, + kernel_stride, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sm_scale, + ) + # transform score to block-wise score + score = transform_score( + score, + kernel_size, + kernel_stride, + block_size, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + init_blocks, + local_blocks, + ) + # get topk + topk = min(topk, score.shape[-1]) + if score.dtype == torch.float32: + score = score.to(torch.bfloat16) + topk_idx = score.topk(topk, dim=-1, sorted=False).indices + topk_idx = topk_idx.sort(-1).values + + topk_idx[topk_idx > q_idx[None, :, None]] = -1 + topk_idx = topk_idx.to(torch.int32) + topk_idx_list.append(topk_idx) + topk_idx = torch.cat(topk_idx_list, dim=0) + + return attn_output, topk_idx diff --git a/vortex_torch/kernels/nsa/flash_attention.py b/vortex_torch/kernels/nsa/flash_attention.py new file mode 100644 index 00000000..c556a4c4 --- /dev/null +++ b/vortex_torch/kernels/nsa/flash_attention.py @@ -0,0 +1,886 @@ +# Copyright 2025 Xunhao Lai & Jianqiao Lu. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Any, Optional + +import torch +import triton +import triton.language as tl + +from .utils import get_num_warps_stages, is_hopper_gpu + +IS_HOPPER_GPU = is_hopper_gpu() + + +@triton.jit +def forward_kernel( + q_ptr, # Q: n x h x d + k_ptr, # K: n x h x d + v_ptr, # V: n x h x d + o_ptr, # O: n x h x d + lse_ptr, # LSE: h x n + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_on, + stride_oh, + stride_od, + stride_lh, + stride_ln, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init qkv pointer + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(HEAD_DIM, k_len), + strides=(stride_kd, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # load q + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + # init statistics + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + m_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + lse_i = tl.full((BLOCK_SIZE_Q,), float("-inf"), dtype=tl.float32) + acc_o = tl.full((BLOCK_SIZE_Q, BLOCK_SIZE_D), 0, dtype=tl.float32) + # full attention or causal attention + lo = 0 + if causal: + hi = min(k_len, (pid_q + 1) * BLOCK_SIZE_Q) + else: + hi = k_len + for i in range(lo, hi, BLOCK_SIZE_K): + i = tl.multiple_of(i, BLOCK_SIZE_K) + # load k + k = tl.load(k_ptrs, boundary_check=(1, 0), padding_option="zero") + # compute qk + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + if causal: + qk += tl.where(off_q[:, None] >= (i + off_k)[None, :], 0, float("-inf")) + else: + qk += tl.where((off_k < k_len - i)[None, :], 0, float("-inf")) + qk += tl.dot(q, k) * qk_scale + # compute m_ij and l_ij + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + p = tl.math.exp2(qk - m_ij[:, None]) + l_ij = tl.sum(p, axis=1) + # scale acc_o + acc_o_scale = tl.math.exp2(m_i - m_ij) + acc_o = acc_o * acc_o_scale[:, None] + # load v and update acc_o + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + # update statistics + m_i = m_ij + lse_i = m_ij + tl.math.log2(tl.math.exp2(lse_i - m_ij) + l_ij) + # update ptrs + k_ptrs = tl.advance(k_ptrs, (0, BLOCK_SIZE_K)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # final scale + acc_o = acc_o * tl.math.exp2(m_i - lse_i)[:, None] + # save output + o_ptrs = tl.make_block_ptr( + base=o_ptr + q_start * stride_on + pid_h * stride_oh, + shape=(q_len, HEAD_DIM), + strides=(stride_on, stride_od), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + tl.store(o_ptrs, acc_o.to(o_ptr.dtype.element_ty), boundary_check=(0, 1)) + # save lse + l_ptrs = lse_ptr + q_start * stride_ln + pid_h * stride_lh + off_q * stride_ln + tl.store(l_ptrs, lse_i, mask=off_q < q_len) + + +@triton.jit +def backward_sum_o_do( + o_ptr, # O: n x h x d + do_ptr, # dO: n x h x d + delta_ptr, # D: h x n + o_len, + HEAD_DIM, + stride_on, + stride_oh, + stride_od, + stride_don, + stride_doh, + stride_dod, + stride_dh, + stride_dn, + BLOCK_SIZE_O: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_n = tl.program_id(0) + pid_h = tl.program_id(1) + off_n = pid_n * BLOCK_SIZE_O + tl.arange(0, BLOCK_SIZE_O) + off_d = tl.arange(0, BLOCK_SIZE_D) + o = tl.load( + o_ptr + off_n[:, None] * stride_on + pid_h * stride_oh + off_d[None, :] * stride_od, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + do = tl.load( + do_ptr + off_n[:, None] * stride_don + pid_h * stride_doh + off_d[None, :] * stride_dod, + mask=(off_n[:, None] < o_len) & (off_d[None, :] < HEAD_DIM), + other=0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + tl.store(delta_ptr + pid_h * stride_dh + off_n * stride_dn, delta, mask=off_n < o_len) + + +@triton.jit +def backward_dkdv( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dk_ptr, # DK: sh x n x kh x d + dv_ptr, # DV: sh x n x kh x d + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dks, + stride_dkn, + stride_dkh, + stride_dkd, + stride_dvs, + stride_dvn, + stride_dvh, + stride_dvd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + if gqa_interleave: + pid_kh = pid_h % NUM_SHARE_Q_HEADS + pid_sh = pid_h // NUM_SHARE_Q_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + pid_sh = pid_h % NUM_SHARE_Q_HEADS + pid_k = tl.program_id(2) + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_K * pid_k >= k_len: + return + # init pointers + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dk_ptrs = tl.make_block_ptr( + base=dk_ptr + k_start * stride_dkn + pid_kh * stride_dkh + pid_sh * stride_dks, + shape=(k_len, HEAD_DIM), + strides=(stride_dkn, stride_dkd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + dv_ptrs = tl.make_block_ptr( + base=dv_ptr + k_start * stride_dvn + pid_kh * stride_dvh + pid_sh * stride_dvs, + shape=(k_len, HEAD_DIM), + strides=(stride_dvn, stride_dvd), + offsets=(pid_k * BLOCK_SIZE_K, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + off_k = tl.arange(0, BLOCK_SIZE_K) + pid_k * BLOCK_SIZE_K + # load k v and keep in SRAM + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dk dv + dk = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + dv = tl.zeros((BLOCK_SIZE_K, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + q_lo = pid_k * BLOCK_SIZE_K + else: + q_lo = 0 + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(q_lo, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # loop for q blocks + for i in range(q_lo, q_len, BLOCK_SIZE_Q): + # load + q = tl.load(q_ptrs, boundary_check=(0, 1), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where((off_q + i)[:, None] >= off_k[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + p = p.to(do.dtype) + ds = ds.to(q.dtype) + # update dk and dv + dk += tl.dot(ds.T, q) + dv += tl.dot(p.T, do) + # increment pointers + q_ptrs = tl.advance(q_ptrs, (BLOCK_SIZE_Q, 0)) + do_ptrs = tl.advance(do_ptrs, (BLOCK_SIZE_Q, 0)) + lse_ptrs = tl.advance(lse_ptrs, (BLOCK_SIZE_Q, 0)) + d_ptrs = tl.advance(d_ptrs, (BLOCK_SIZE_Q, 0)) + # save dk dv + tl.store(dk_ptrs, dk.to(dk_ptr.dtype.element_ty), boundary_check=(0, 1)) + tl.store(dv_ptrs, dv.to(dv_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def backward_dq( + q_ptr, # Q: n x qh x d + k_ptr, # K: n x kh x d + v_ptr, # V: n x kh x d + lse_ptr, # LSE: qh x n + d_ptr, # Delta: qh x n + do_ptr, + dq_ptr, + # seqlens + cu_seqlens_q, + cu_seqlens_k, + # shape + NUM_KV_HEADS, + NUM_SHARE_Q_HEADS, + HEAD_DIM, + # sm_scale + sm_scale, + # causal + causal, + # gqa + gqa_interleave, + # stride + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_lh, + stride_ln, + stride_dh, + stride_dn, + stride_don, + stride_doh, + stride_dod, + stride_dqn, + stride_dqh, + stride_dqd, + # META parameters + BLOCK_SIZE_Q: tl.constexpr, # q block size + BLOCK_SIZE_K: tl.constexpr, # k block size + BLOCK_SIZE_D: tl.constexpr, +): + qk_scale = sm_scale * 1.44269504 + # get batch id and head id + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_q = tl.program_id(2) + if gqa_interleave: + pid_kh = pid_h % NUM_KV_HEADS + else: + pid_kh = pid_h // NUM_SHARE_Q_HEADS + # get q k start and len after rmpad + q_start = tl.load(cu_seqlens_q + pid_b) + q_len = tl.load(cu_seqlens_q + pid_b + 1) - q_start + k_start = tl.load(cu_seqlens_k + pid_b) + k_len = tl.load(cu_seqlens_k + pid_b + 1) - k_start + if BLOCK_SIZE_Q * pid_q >= q_len: + return + # init pointers + q_ptrs = tl.make_block_ptr( + base=q_ptr + q_start * stride_qn + pid_h * stride_qh, + shape=(q_len, HEAD_DIM), + strides=(stride_qn, stride_qd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + dq_ptrs = tl.make_block_ptr( + base=dq_ptr + q_start * stride_dqn + pid_h * stride_dqh, + shape=(q_len, HEAD_DIM), + strides=(stride_dqn, stride_dqd), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + k_ptrs = tl.make_block_ptr( + base=k_ptr + k_start * stride_kn + pid_kh * stride_kh, + shape=(k_len, HEAD_DIM), + strides=(stride_kn, stride_kd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + v_ptrs = tl.make_block_ptr( + base=v_ptr + k_start * stride_vn + pid_kh * stride_vh, + shape=(k_len, HEAD_DIM), + strides=(stride_vn, stride_vd), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + do_ptrs = tl.make_block_ptr( + base=do_ptr + q_start * stride_don + pid_h * stride_doh, + shape=(q_len, HEAD_DIM), + strides=(stride_don, stride_dod), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, BLOCK_SIZE_D), + order=(1, 0), + ) + d_ptrs = tl.make_block_ptr( + base=d_ptr + q_start * stride_dn + pid_h * stride_dh, + shape=(q_len, 1), + strides=(stride_dn, stride_dh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + lse_ptrs = tl.make_block_ptr( + base=lse_ptr + q_start * stride_ln + pid_h * stride_lh, + shape=(q_len, 1), + strides=(stride_ln, stride_lh), + offsets=(pid_q * BLOCK_SIZE_Q, 0), + block_shape=(BLOCK_SIZE_Q, 1), + order=(0, 1), + ) + # offsets + off_q = tl.arange(0, BLOCK_SIZE_Q) + pid_q * BLOCK_SIZE_Q + off_k = tl.arange(0, BLOCK_SIZE_K) + # load q, do, lse, delta, and keep in SRAM + q = tl.load(q_ptrs, boundary_check=(1, 0), padding_option="zero") + do = tl.load(do_ptrs, boundary_check=(0, 1), padding_option="zero") + lse = tl.load(lse_ptrs, boundary_check=(0, 1), padding_option="zero") + d = tl.load(d_ptrs, boundary_check=(0, 1), padding_option="zero") + # init dq + dq = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_D), dtype=tl.float32) + # causal + if causal: + k_hi = (pid_q + 1) * BLOCK_SIZE_Q + else: + k_hi = k_len + for j in range(0, k_hi, BLOCK_SIZE_K): + # load + k = tl.load(k_ptrs, boundary_check=(0, 1), padding_option="zero") + v = tl.load(v_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute qk + if causal: + qk = tl.where(off_q[:, None] >= (off_k + j)[None, :], float(0.0), float("-inf")) + else: + qk = tl.zeros((BLOCK_SIZE_Q, BLOCK_SIZE_K), dtype=tl.float32) + qk += tl.dot(q, k.T) * qk_scale + # compute p, ds + p = tl.math.exp2(qk - lse) + dp = tl.dot(do, v.T) + ds = sm_scale * p * (dp - d) + # cast dtype + ds = ds.to(q.dtype) + # update dq + dq += tl.dot(ds, k) + # increment pointers + k_ptrs = tl.advance(k_ptrs, (BLOCK_SIZE_K, 0)) + v_ptrs = tl.advance(v_ptrs, (BLOCK_SIZE_K, 0)) + # save dq + tl.store(dq_ptrs, dq.to(dq_ptr.dtype.element_ty), boundary_check=(0, 1)) + + +def _flash_attention_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool, + sm_scale: float, + gqa_interleave: bool = False, +): + # dtype check + assert q.dtype == torch.bfloat16 or q.dtype == torch.float16 + assert k.dtype == q.dtype and v.dtype == q.dtype + assert cu_seqlens_q.dtype == torch.int32 and cu_seqlens_k.dtype == torch.int32 + # shape + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + batch_size = cu_seqlens_q.shape[0] - 1 + # assert q_len == k_len and k_len == v_len + # gqa + assert num_k_heads == num_v_heads + assert num_q_heads % num_k_heads == 0 + num_share_q_heads = num_q_heads // num_k_heads + # output tensor + o = torch.empty_like(q) + lse = torch.empty(num_q_heads, q_len, dtype=torch.float32, device=q.device) + # launch kernel + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + forward_kernel[grid]( + q, + k, + v, + o, + lse, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + lse.stride(0), + lse.stride(1), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return o, lse + + +def _flash_attention_bwd( + o: torch.Tensor, + do: torch.Tensor, + lse: torch.Tensor, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool, + sm_scale: float, + gqa_interleave: bool = False, +): + q_len, num_q_heads, head_dim = q.shape + k_len, num_k_heads, head_dim = k.shape + v_len, num_v_heads, head_dim = v.shape + o_len, num_o_heads, head_dim = o.shape + num_share_q_heads = num_q_heads // num_k_heads + # compute D + delta = torch.empty([num_o_heads, o_len], device=o.device, dtype=torch.float32) + grid = lambda META: (triton.cdiv(o_len, META["BLOCK_SIZE_O"]), num_o_heads) + BLOCK_SIZE_O = 256 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_O, IS_HOPPER_GPU) + backward_sum_o_do[grid]( + o, + do, + delta, + o_len, + head_dim, + o.stride(0), + o.stride(1), + o.stride(2), + do.stride(0), + do.stride(1), + do.stride(2), + delta.stride(0), + delta.stride(1), + BLOCK_SIZE_O=BLOCK_SIZE_O, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + # compute dk dv + dk = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + dv = torch.empty(num_share_q_heads, k_len, num_k_heads, head_dim, device=k.device, dtype=k.dtype) + batch_size = cu_seqlens_q.shape[0] - 1 + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_k, META["BLOCK_SIZE_K"]), + ) + BLOCK_SIZE_Q = 64 + BLOCK_SIZE_K = 64 + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_K, IS_HOPPER_GPU) + backward_dkdv[grid]( + q, + k, + v, + lse, + delta, + do, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dk.stride(0), + dk.stride(1), + dk.stride(2), + dk.stride(3), + dv.stride(0), + dv.stride(1), + dv.stride(2), + dv.stride(3), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + dk = dk.sum(0) + dv = dv.sum(0) + # compute dq + dq = torch.empty_like(q) + grid = lambda META: ( + batch_size, + num_q_heads, + triton.cdiv(max_seqlen_q, META["BLOCK_SIZE_Q"]), + ) + BLOCK_SIZE_Q = 128 + BLOCK_SIZE_K = 64 + num_warps, num_stages = get_num_warps_stages(head_dim, BLOCK_SIZE_Q, IS_HOPPER_GPU) + backward_dq[grid]( + q, + k, + v, + lse, + delta, + do, + dq, + cu_seqlens_q, + cu_seqlens_k, + num_k_heads, + num_share_q_heads, + head_dim, + sm_scale, + causal, + gqa_interleave, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + lse.stride(0), + lse.stride(1), + delta.stride(0), + delta.stride(1), + do.stride(0), + do.stride(1), + do.stride(2), + dq.stride(0), + dq.stride(1), + dq.stride(2), + BLOCK_SIZE_Q=BLOCK_SIZE_Q, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + num_warps=num_warps, + num_stages=num_stages, + ) + return dq, dk, dv + + +class FlashAttention(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal=True, + sm_scale=None, + gqa_interleave=False, + ): + # softmax scale + if sm_scale is None: + sm_scale = 1 / math.sqrt(q.shape[-1]) + o, lse = _flash_attention_fwd( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) + ctx.save_for_backward(q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k) + ctx.sm_scale = sm_scale + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.causal = causal + ctx.gqa_interleave = gqa_interleave + return o + + @staticmethod + def backward(ctx, do: torch.Tensor, *args) -> Any: + q, k, v, o, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + max_seqlen_q = ctx.max_seqlen_q + max_seqlen_k = ctx.max_seqlen_k + sm_scale = ctx.sm_scale + causal = ctx.causal + gqa_interleave = ctx.gqa_interleave + dq, dk, dv = _flash_attention_bwd( + o, + do, + lse, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) + return dq, dk, dv, None, None, None, None, None, None, None + + +def flash_attention_varlen( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: torch.Tensor, + max_seqlen_k: torch.Tensor, + causal: bool = False, + sm_scale: Optional[float] = None, + gqa_interleave: bool = False, +) -> torch.Tensor: + """Flash attention with variable length based on triton. + + Args: + q (torch.Tensor): shape [total_q_len, num_q_heads, head_dim] + k (torch.Tensor): shape [total_kv_len, num_q_heads, head_dim] + v (torch.Tensor): shape [total_kv_len, num_q_heads, head_dim] + cu_seqlens_q (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_q in flash_attn_func_varlen. + cu_seqlens_k (torch.Tensor): shape [batch_size + 1], similar to cu_seqlens_k in flash_attn_func_varlen. + max_seqlen_q (torch.Tensor): max q len of the batch. + max_seqlen_k (torch.Tensor): max k len of the batch. + causal (bool, optional): Causal mask. Defaults to False. + sm_scale (float, optional): softmax scale. Defaults to None, means 1/sqrt(head_dim). + gqa_interleave (bool, optional): GQA pattern. Defaults to False, use Llama style GQA. + + Returns: + torch.Tensor: attention output with shape [total_q_len, num_q_heads, head_dim] + """ + return FlashAttention.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + causal, + sm_scale, + gqa_interleave, + ) diff --git a/vortex_torch/kernels/nsa/utils.py b/vortex_torch/kernels/nsa/utils.py new file mode 100644 index 00000000..1f158a17 --- /dev/null +++ b/vortex_torch/kernels/nsa/utils.py @@ -0,0 +1,50 @@ +import torch + + +def is_hopper_gpu(): + if torch.cuda.is_available(): + device_capability = torch.cuda.get_device_capability(0) + major, minor = device_capability + return major == 9 + return False + + +def get_num_warps_stages(head_dim, block_size, is_hopper_gpu): + """ + Returns recommended num_warps and num_stages for a Sparse Attention kernel in Triton. + + Args: + head_dim (int): Size of the head dimension. + block_size (int): Size of the block in the attention matrix. + is_hopper_gpu (bool): True if Hopper GPU, False if Ampere GPU. + + Returns: + tuple: (num_warps, num_stages) recommended values. + """ + # Determine if head_dim and block_size exceed 64 + head_large = head_dim > 64 + block_large = block_size > 64 + + if is_hopper_gpu: + # Hopper GPU recommendations + if head_large and block_large: + num_warps = 8 + num_stages = 3 + elif head_large or block_large: + num_warps = 4 + num_stages = 3 + else: + num_warps = 2 + num_stages = 2 + else: + # Ampere GPU recommendations + if head_large and block_large: + num_warps = 8 + num_stages = 3 + elif head_large or block_large: + num_warps = 8 + num_stages = 3 + else: + num_warps = 2 + num_stages = 2 + return num_warps, num_stages diff --git a/vortex_torch/kernels/nsa/weighted_pool.py b/vortex_torch/kernels/nsa/weighted_pool.py new file mode 100644 index 00000000..abfe9d30 --- /dev/null +++ b/vortex_torch/kernels/nsa/weighted_pool.py @@ -0,0 +1,341 @@ +# Copyright 2025 Xunhao Lai. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +import torch +import triton +import triton.language as tl +from einops import einsum + + +@triton.jit +def sliding_pool_fwd_kernel( + x_ptr, + y_ptr, + w_ptr, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + stride_xn, + stride_xh, + stride_xd, + stride_yn, + stride_yh, + stride_yd, + stride_wh, + stride_wk, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_k = tl.program_id(2) + # get start and len after rmpad + x_start = tl.load(cu_seqlens + pid_b) + x_len = tl.load(cu_seqlens + pid_b + 1) - x_start + y_start = tl.load(y_cu_seqlens + pid_b) + y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start + if pid_k >= y_len: + return + if w_ptr is not None: + # load w + w_ptrs = tl.make_block_ptr( + base=w_ptr + pid_h * stride_wh, + shape=(kernel_size, 1), + strides=(stride_wk, 0), + offsets=(0, 0), + block_shape=(BLOCK_SIZE_K, 1), + order=(0, 1), + ) + w = tl.load(w_ptrs, boundary_check=(0, 1), padding_option="zero") + # load x + x_ptrs = tl.make_block_ptr( + base=x_ptr + x_start * stride_xn + pid_h * stride_xh, + shape=(x_len, head_dim), + strides=(stride_xn, stride_xd), + offsets=(pid_k * kernel_stride, 0), + block_shape=(BLOCK_SIZE_K, BLOCK_SIZE_D), + order=(1, 0), + ) + x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero") + # compute y + if w_ptr is not None: + y = tl.sum(x * w, axis=0) + else: + y = tl.sum(x, axis=0) / kernel_size + off_d = tl.arange(0, BLOCK_SIZE_D) + tl.store( + y_ptr + (y_start + pid_k) * stride_yn + pid_h * stride_yh + off_d * stride_yd, + y.to(y_ptr.dtype.element_ty), + mask=off_d < head_dim, + ) + + +@triton.jit +def sliding_pool_dxdw_kernel( + x_ptr, + dx_ptr, + dy_ptr, + w_ptr, + dw_ptr, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + stride_xn, + stride_xh, + stride_xd, + stride_dxn, + stride_dxh, + stride_dxd, + stride_dyn, + stride_dyh, + stride_dyd, + stride_wh, + stride_wk, + stride_dwh, + stride_dwn, + stride_dwk, + BLOCK_SIZE_K: tl.constexpr, + BLOCK_SIZE_D: tl.constexpr, +): + pid_b = tl.program_id(0) + pid_h = tl.program_id(1) + pid_k = tl.program_id(2) + # get start and len after rmpad + x_start = tl.load(cu_seqlens + pid_b) + x_len = tl.load(cu_seqlens + pid_b + 1) - x_start + y_start = tl.load(y_cu_seqlens + pid_b) + y_len = tl.load(y_cu_seqlens + pid_b + 1) - y_start + if pid_k >= y_len: + return + # offsets + off_d = tl.arange(0, BLOCK_SIZE_D) + off_k = tl.arange(0, BLOCK_SIZE_K) + if w_ptr is not None: + # load w + w_ptrs = w_ptr + pid_h * stride_wh + off_k * stride_wk + w = tl.load(w_ptrs, mask=off_k < kernel_size, other=0) + # load x + x_ptrs = tl.make_block_ptr( + base=x_ptr + x_start * stride_xn + pid_h * stride_xh, + shape=(head_dim, x_len), + strides=(stride_xd, stride_xn), + offsets=(0, pid_k * kernel_stride), + block_shape=(BLOCK_SIZE_D, BLOCK_SIZE_K), + order=(0, 1), + ) + x = tl.load(x_ptrs, boundary_check=(0, 1), padding_option="zero") + # load dy + dy_ptrs = dy_ptr + pid_h * stride_dyh + (y_start + pid_k) * stride_dyn + off_d * stride_dyd + dy = tl.load(dy_ptrs, mask=off_d < head_dim, other=0) + if w_ptr is not None: + # compute dx, [1, D] x [K, 1] -> [K, D] + dx = dy[None, :] * w[:, None] + # compute dw, [D, 1] x [D, K] -> [D, K] -> [K] + dw = tl.sum(dy[:, None] * x, axis=0) + # store dw + dw_ptrs = dw_ptr + pid_h * stride_dwh + (y_start + pid_k) * stride_dwn + off_k * stride_dwk + tl.store(dw_ptrs, dw.to(dw_ptr.dtype.element_ty), mask=off_k < kernel_size) + else: + dx = dy[None, :] / kernel_size + # store dx + dx_ptrs = ( + dx_ptr + + pid_h * stride_dxh + + (x_start + pid_k * kernel_stride + off_k[:, None]) * stride_dxn + + off_d[None, :] * stride_dxd + ) + tl.atomic_add( + dx_ptrs, + dx.to(dx_ptr.dtype.element_ty), + mask=(off_k < x_len - pid_k * kernel_stride)[:, None] & (off_d < head_dim)[None, :], + ) + + +class SlidingWindowWeightedPool(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, # [total_len, num_heads, head_dim] + w: torch.Tensor, # [num_heads, kernel_size] + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + ): + # dtype check + assert x.dtype == torch.float16 or x.dtype == torch.bfloat16 + if w is not None: + assert x.dtype == w.dtype + assert cu_seqlens.dtype == torch.int32 + # shape check + total_len, num_heads, head_dim = x.shape + batch_size = cu_seqlens.shape[0] - 1 + if w is not None: + assert w.shape[0] == num_heads + assert w.shape[1] == kernel_size + assert kernel_size % kernel_stride == 0 + assert kernel_size in {16, 32, 64, 128} + # compute seqlens after compression + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + y_seqlens = torch.floor((seqlens - kernel_size) / kernel_stride).to(torch.int32) + 1 + # corner case, if sequence_length < kernel_size, no compression for this sequence + y_seqlens[seqlens < kernel_size] = 0 + y_cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device="cuda"), + torch.cumsum(y_seqlens, dim=0), + ], + dim=0, + ).to(torch.int32) + # output buffer + y = torch.zeros(y_cu_seqlens[-1], num_heads, head_dim, dtype=x.dtype, device=x.device) + # launch kernel + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(kernel_size) + grid = (batch_size, num_heads, y_seqlens.max().item()) + sliding_pool_fwd_kernel[grid]( + x, + y, + w, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + x.stride(0), + x.stride(1), + x.stride(2), + y.stride(0), + y.stride(1), + y.stride(2), + w.stride(0) if w is not None else None, + w.stride(1) if w is not None else None, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + ctx.save_for_backward(x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens) + ctx.kernel_size = kernel_size + ctx.kernel_stride = kernel_stride + ctx.head_dim = head_dim + return y, y_cu_seqlens + + @staticmethod + def backward(ctx, dy, _): + x, w, seqlens, cu_seqlens, y_seqlens, y_cu_seqlens = ctx.saved_tensors + kernel_size = ctx.kernel_size + kernel_stride = ctx.kernel_stride + head_dim = ctx.head_dim + batch_size = cu_seqlens.shape[0] - 1 + num_heads = x.shape[1] + # compute dx + dx = torch.zeros_like(x, dtype=torch.float32) + if w is not None: + dw = torch.zeros( + num_heads, + y_cu_seqlens[-1], + kernel_size, + dtype=torch.float32, + device=w.device, + ) + BLOCK_SIZE_D = triton.next_power_of_2(head_dim) + BLOCK_SIZE_K = triton.next_power_of_2(kernel_size) + grid = (batch_size, num_heads, y_seqlens.max().item()) + sliding_pool_dxdw_kernel[grid]( + x, + dx, + dy, + w, + dw if w is not None else None, + cu_seqlens, + y_cu_seqlens, + head_dim, + kernel_size, + kernel_stride, + x.stride(0), + x.stride(1), + x.stride(2), + dx.stride(0), + dx.stride(1), + dx.stride(2), + dy.stride(0), + dy.stride(1), + dy.stride(2), + w.stride(0) if w is not None else None, + w.stride(1) if w is not None else None, + dw.stride(0) if w is not None else None, + dw.stride(1) if w is not None else None, + dw.stride(2) if w is not None else None, + BLOCK_SIZE_K=BLOCK_SIZE_K, + BLOCK_SIZE_D=BLOCK_SIZE_D, + ) + dx = dx.to(x.dtype) + if w is None: + dw = None + else: + dw = dw.sum(1).to(w.dtype) + return dx, dw, None, None, None + + +def weightedpool_compress( + x: torch.Tensor, # [total_len, num_heads, head_dim] + w: torch.Tensor, # [num_heads, kernel_size] + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + pe: Optional[torch.Tensor] = None, +): + y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride) + if pe is not None: + assert pe.dtype == x.dtype and pe.device == x.device + bias = einsum(pe, w, "h k d, h k -> h d") + y = y + bias.unsqueeze(0) + return y, y_cu_seqlens + + +def avgpool_compress( + x: torch.Tensor, # [total_len, num_heads, head_dim] + w: torch.Tensor, # don't need weight + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + pe: Optional[torch.Tensor] = None, +): + assert w is None, "don't need additional weight for avgpool" + y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w, cu_seqlens, kernel_size, kernel_stride) + if pe is not None: + assert pe.dtype == x.dtype and pe.device == x.device + bias = torch.mean(pe, dim=1) + y = y + bias.unsqueeze(0) + return y, y_cu_seqlens + + +def softmaxpool_compress( + x: torch.Tensor, + w: torch.Tensor, + cu_seqlens: torch.Tensor, + kernel_size: int, + kernel_stride: int, + pe: Optional[torch.Tensor] = None, +): + y, y_cu_seqlens = SlidingWindowWeightedPool.apply(x, w.softmax(-1), cu_seqlens, kernel_size, kernel_stride) + if pe is not None: + assert pe.dtype == x.dtype and pe.device == x.device + bias = torch.mean(pe, dim=1) + y = y + bias.unsqueeze(0) + return y, y_cu_seqlens From f6ca879efa5db14b7193dbb2fa9c417df579f647 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 30 Mar 2026 05:21:24 +0000 Subject: [PATCH 12/24] Refactor on the int8 quanitzation, mainly the dequant kernels of int8 --- examples/verify_algo.sh | 2 +- examples/verify_algo_quant.sh | 4 +- examples/verify_sparse_backends.sh | 2 +- vortex_torch/cache/__init__.py | 4 +- vortex_torch/cache/triton_kernels/__init__.py | 18 +- .../cache/triton_kernels/paged_decode_int8.py | 363 ------------- .../triton_kernels/paged_prefill_int8.py | 168 ------ vortex_torch/cache/triton_kernels/set_kv.py | 495 ++++++++++++++++++ 8 files changed, 513 insertions(+), 543 deletions(-) delete mode 100644 vortex_torch/cache/triton_kernels/paged_decode_int8.py delete mode 100644 vortex_torch/cache/triton_kernels/paged_prefill_int8.py diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 0dcbe9fc..aa01fe66 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=5 +export CUDA_VISIBLE_DEVICES=6 sparse_algos=( "block_sparse_attention" diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh index c344474a..a7601de9 100644 --- a/examples/verify_algo_quant.sh +++ b/examples/verify_algo_quant.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=0 +export CUDA_VISIBLE_DEVICES=6 sparse_algos=( "block_sparse_attention" @@ -20,6 +20,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --kv-cache-dtype int8 \ + --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -34,6 +35,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --kv-cache-dtype fp8_e4m3 \ + --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/examples/verify_sparse_backends.sh b/examples/verify_sparse_backends.sh index 81b3562d..12600d08 100755 --- a/examples/verify_sparse_backends.sh +++ b/examples/verify_sparse_backends.sh @@ -1,6 +1,6 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=5 +export CUDA_VISIBLE_DEVICES=6 sparse_algos=( "nsa" diff --git a/vortex_torch/cache/__init__.py b/vortex_torch/cache/__init__.py index 8c4d0e0f..6b549054 100644 --- a/vortex_torch/cache/__init__.py +++ b/vortex_torch/cache/__init__.py @@ -29,14 +29,14 @@ from .matmul import GeMM from .elementwise import Relu, Silu, Sigmoid, Abs, Add_Mul from .elementwise_binary import Maximum, Minimum, Multiply, Add -from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_paged_int8_to_bf16_inplace +from .triton_kernels import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher, dequant_pages_to_bf16_inplace __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "set_kv_buffer_fp8_launcher", - "dequant_paged_int8_to_bf16_inplace", + "dequant_pages_to_bf16_inplace", "Mean", "Max", "Min", "L2Norm", "GeMM", "Relu", "Silu", "Sigmoid", "Abs", "Add_Mul", diff --git a/vortex_torch/cache/triton_kernels/__init__.py b/vortex_torch/cache/triton_kernels/__init__.py index 009e728e..de4fcbdc 100644 --- a/vortex_torch/cache/triton_kernels/__init__.py +++ b/vortex_torch/cache/triton_kernels/__init__.py @@ -1,13 +1,17 @@ -from .set_kv import set_kv_buffer_launcher, set_kv_buffer_int8_launcher, set_kv_buffer_fp8_launcher -from .paged_decode_int8 import paged_decode_int8 -from .paged_prefill_int8 import dequant_paged_int8_to_bf16, dequant_paged_int8_to_bf16_inplace +from .set_kv import ( + set_kv_buffer_launcher, + set_kv_buffer_int8_launcher, + set_kv_buffer_fp8_launcher, + paged_decode, + dequant_pages_to_bf16, + dequant_pages_to_bf16_inplace, +) __all__ = [ "set_kv_buffer_launcher", "set_kv_buffer_int8_launcher", "set_kv_buffer_fp8_launcher", - "paged_decode_int8", - "dequant_paged_int8_to_bf16", - "dequant_paged_int8_to_bf16_inplace", + "paged_decode", + "dequant_pages_to_bf16", + "dequant_pages_to_bf16_inplace", ] - diff --git a/vortex_torch/cache/triton_kernels/paged_decode_int8.py b/vortex_torch/cache/triton_kernels/paged_decode_int8.py deleted file mode 100644 index 4f33cd45..00000000 --- a/vortex_torch/cache/triton_kernels/paged_decode_int8.py +++ /dev/null @@ -1,363 +0,0 @@ -""" -Custom Triton paged decode attention kernel for int8 KV cache. - -Loads int8 K/V pages with per-token float32 scales, dequantizes inline in SRAM, -and computes standard multi-head attention with online softmax. - -Adapted from SGLang's decode_attention.py for use with Vortex's paged layout -where each KV head is treated as a separate "batch" entry. -""" - -import torch -import triton -import triton.language as tl - -_MIN_BLOCK_KV = 32 - - -@triton.jit -def tanh(x): - return 2 * tl.sigmoid(2 * x) - 1 - - -@triton.jit -def _fwd_kernel_int8_stage1( - Q, # [batch, num_qo_heads, head_dim] bf16 - K_Buffer, # int8 paged: flat - V_Buffer, # int8 paged: flat - K_Scale_Buffer, # fp16: flat (one scale per token slot) - V_Scale_Buffer, # fp16: flat - sm_scale, - kv_indptr, # [batch + 1] int32, page-level - kv_indices, # page indices - last_page_len, # [batch] int32, tokens valid in last page - Att_Out, # [batch, num_qo_heads, max_kv_splits, head_dim] - Att_Lse, # [batch, num_qo_heads, max_kv_splits] - num_kv_splits, # [batch] int32 - stride_qbs, - stride_qh, - stride_buf_kbs, # stride per token in K_Buffer (= head_dim) - stride_buf_vbs, # stride per token in V_Buffer (= head_dim) - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - kv_group_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_DV: tl.constexpr, - BLOCK_N: tl.constexpr, - MIN_BLOCK_KV: tl.constexpr, - logit_cap: tl.constexpr, - Lk: tl.constexpr, - Lv: tl.constexpr, - PAGE_SIZE: tl.constexpr, -): - """ - Stage 1: For each (batch, head, kv_split), compute partial attention output and LSE. - - kv_indptr is page-level. Total tokens for batch i: - (num_pages - 1) * PAGE_SIZE + last_page_len[i] - """ - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - split_kv_id = tl.program_id(2) - - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_dv = tl.arange(0, BLOCK_DV) - mask_d = offs_d < Lk - mask_dv = offs_dv < Lv - - cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) - cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx - cur_last_page_len = tl.load(last_page_len + cur_batch) - # Correct token count accounting for partial last page - cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len - kv_splits = tl.load(num_kv_splits + cur_batch) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - - kv_len_per_split = ( - tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV - ) - split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - - e_max = -float("inf") - e_sum = 0.0 - acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - - if split_kv_end > split_kv_start: - q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) - - for start_n in range(split_kv_start, split_kv_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - mask_n = offs_n < split_kv_end - - # Convert token offsets to page_id + in-page offset - page_indices_in_seq = offs_n // PAGE_SIZE - in_page_offsets = offs_n % PAGE_SIZE - - # Load page indices from kv_indices (physical page IDs) - page_ids = tl.load( - kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, - mask=mask_n, - other=0, - ) - - # Flat token location: physical_page * PAGE_SIZE + in_page_offset - kv_loc = page_ids * PAGE_SIZE + in_page_offsets - - # Load int8 K and dequantize - offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] - k_int8 = tl.load( - K_Buffer + offs_buf_k, - mask=mask_n[:, None] & mask_d[None, :], - other=0, - ).to(tl.float32) - - k_scale = tl.load( - K_Scale_Buffer + kv_loc, - mask=mask_n, - other=1.0, - ).to(tl.float32) - k = k_int8 * k_scale[:, None] - - # Compute QK - qk = tl.sum(q[None, :] * k, 1) - qk *= sm_scale - - if logit_cap > 0: - qk = logit_cap * tanh(qk / logit_cap) - - qk = tl.where(mask_n, qk, float("-inf")) - - # Load int8 V and dequantize - offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] - v_int8 = tl.load( - V_Buffer + offs_buf_v, - mask=mask_n[:, None] & mask_dv[None, :], - other=0, - ).to(tl.float32) - - v_scale = tl.load( - V_Scale_Buffer + kv_loc, - mask=mask_n, - other=1.0, - ).to(tl.float32) - v = v_int8 * v_scale[:, None] - - # Online softmax accumulation - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - re_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - acc *= re_scale - acc += tl.sum(p[:, None] * v, 0) - - e_sum = e_sum * re_scale + tl.sum(p, 0) - e_max = n_e_max - - offs_mid_o = ( - cur_batch * stride_mid_ob - + cur_head * stride_mid_oh - + split_kv_id * stride_mid_os - + offs_dv - ) - - tl.store( - Att_Out + offs_mid_o, - acc / e_sum, - mask=mask_dv, - ) - - offs_mid_o_1 = ( - cur_batch * stride_mid_ob - + cur_head * stride_mid_oh - + split_kv_id * stride_mid_os - ) // Lv - - tl.store( - Att_Lse + offs_mid_o_1, - e_max + tl.log(e_sum), - ) - - -@triton.jit -def _fwd_kernel_int8_stage2( - Mid_O, - Mid_O_1, - O, - kv_indptr, - last_page_len, - num_kv_splits, - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_obs, - stride_oh, - MAX_KV_SPLITS: tl.constexpr, - MIN_BLOCK_KV: tl.constexpr, - BLOCK_DV: tl.constexpr, - Lv: tl.constexpr, - PAGE_SIZE: tl.constexpr, -): - """Stage 2: Reduce split outputs via log-sum-exp merge.""" - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - - cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) - cur_last_page_len = tl.load(last_page_len + cur_batch) - cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len - kv_splits = tl.load(num_kv_splits + cur_batch) - - offs_d = tl.arange(0, BLOCK_DV) - mask_d = offs_d < Lv - - e_sum = 0.0 - e_max = -float("inf") - acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - - offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d - offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv - kv_len_per_split = ( - tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV - ) - - for split_kv_id in range(0, MAX_KV_SPLITS): - split_kv_start = kv_len_per_split * split_kv_id - split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - - if split_kv_end > split_kv_start: - tv = tl.load( - Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 - ) - tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) - n_e_max = tl.maximum(tlogic, e_max) - - old_scale = tl.exp(e_max - n_e_max) - acc *= old_scale - exp_logic = tl.exp(tlogic - n_e_max) - acc += exp_logic * tv - - e_sum = e_sum * old_scale + exp_logic - e_max = n_e_max - - tl.store( - O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, - acc / e_sum, - mask=mask_d, - ) - - -def paged_decode_int8( - q: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 - k_buffer: torch.Tensor, # int8 paged K cache - v_buffer: torch.Tensor, # int8 paged V cache - k_scale_buffer: torch.Tensor, # fp16 scale for K - v_scale_buffer: torch.Tensor, # fp16 scale for V - o: torch.Tensor, # [batch, num_qo_heads, head_dim] bf16 output - kv_indptr: torch.Tensor, # [batch + 1] int32, page-level - kv_indices: torch.Tensor, # page indices - last_page_len: torch.Tensor, # [batch] int32 - num_kv_splits: torch.Tensor, # [batch] int32 - max_kv_splits: int, - sm_scale: float, - page_size: int, - logit_cap: float = 0.0, - att_out: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits, Lv] - att_lse: torch.Tensor = None, # optional pre-allocated [batch, head_num, max_kv_splits] -): - """ - Paged decode attention with int8 KV cache and inline dequantization. - - kv_indptr is page-level. last_page_len specifies valid tokens in the last page - for each batch entry. Total tokens = (num_pages - 1) * page_size + last_page_len. - """ - batch = q.shape[0] - head_num = q.shape[1] - Lk = q.shape[2] - Lv = Lk - - BLOCK_DMODEL = triton.next_power_of_2(Lk) - BLOCK_DV = triton.next_power_of_2(Lv) - BLOCK_N = 64 - MAX_KV_SPLITS = max_kv_splits - - kv_group_num = head_num - - num_warps = 4 if kv_group_num == 1 else 2 - - # Use pre-allocated buffers if provided, otherwise allocate - if att_out is None: - att_out = torch.empty( - (batch, head_num, MAX_KV_SPLITS, Lv), - dtype=torch.float32, - device=q.device, - ) - else: - att_out = att_out[:batch] - if att_lse is None: - att_lse = torch.empty( - (batch, head_num, MAX_KV_SPLITS), - dtype=torch.float32, - device=q.device, - ) - else: - att_lse = att_lse[:batch] - - stride_buf_kbs = k_buffer.shape[-1] - stride_buf_vbs = v_buffer.shape[-1] - - grid_stage1 = (batch, head_num, MAX_KV_SPLITS) - _fwd_kernel_int8_stage1[grid_stage1]( - q, - k_buffer, - v_buffer, - k_scale_buffer, - v_scale_buffer, - sm_scale, - kv_indptr, - kv_indices, - last_page_len, - att_out, - att_lse, - num_kv_splits, - q.stride(0), - q.stride(1), - stride_buf_kbs, - stride_buf_vbs, - att_out.stride(0), - att_out.stride(1), - att_out.stride(2), - kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_DV=BLOCK_DV, - BLOCK_N=BLOCK_N, - MIN_BLOCK_KV=_MIN_BLOCK_KV, - logit_cap=logit_cap, - num_warps=num_warps, - num_stages=2, - Lk=Lk, - Lv=Lv, - PAGE_SIZE=page_size, - ) - - grid_stage2 = (batch, head_num) - _fwd_kernel_int8_stage2[grid_stage2]( - att_out, - att_lse, - o, - kv_indptr, - last_page_len, - num_kv_splits, - att_out.stride(0), - att_out.stride(1), - att_out.stride(2), - o.stride(0), - o.stride(1), - MAX_KV_SPLITS=MAX_KV_SPLITS, - MIN_BLOCK_KV=_MIN_BLOCK_KV, - BLOCK_DV=BLOCK_DV, - Lv=Lv, - PAGE_SIZE=page_size, - num_warps=4, - num_stages=2, - ) diff --git a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py b/vortex_torch/cache/triton_kernels/paged_prefill_int8.py deleted file mode 100644 index 89279833..00000000 --- a/vortex_torch/cache/triton_kernels/paged_prefill_int8.py +++ /dev/null @@ -1,168 +0,0 @@ -""" -OOM-safe bf16 fallback for int8 KV-cache prefill. - -Instead of implementing full 2D-tiled Triton prefill with int8 dequantization, -this module dequantizes only the accessed KV pages into a compact temporary -bf16 buffer and remaps indices so FlashInfer can operate on the compact buffer. - -This avoids dequantizing the entire global cache buffer. -""" - -import torch -import triton -import triton.language as tl - - -@triton.jit -def _dequant_pages_kernel( - src_int8, # int8 paged buffer [num_pages, page_size, head_dim] flat - src_scale, # fp16 scale buffer [num_pages, page_size, 1] flat - dst_bf16, # bf16 compact buffer [num_accessed_pages, page_size, head_dim] flat - page_indices, # int32 [num_accessed_pages] — which global pages to dequant - NUM_PAGES: tl.constexpr, - PAGE_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_DIM: tl.constexpr, -): - """Dequantize selected int8 pages to bf16 compact buffer.""" - page_idx = tl.program_id(0) # index into page_indices - token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) - - if page_idx >= NUM_PAGES: - return - - global_page_id = tl.load(page_indices + page_idx) - dims = tl.arange(0, BLOCK_DIM) - mask_dim = dims < HEAD_DIM - - # Source: global_page_id * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims - src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims - val_int8 = tl.load(src_int8 + src_offset, mask=mask_dim, other=0).to(tl.float32) - - # Scale: global_page_id * PAGE_SIZE + token_idx - scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset).to(tl.float32) - - val_bf16 = (val_int8 * scale).to(tl.bfloat16) - - # Destination: page_idx * PAGE_SIZE * HEAD_DIM + token_idx * HEAD_DIM + dims - dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims - tl.store(dst_bf16 + dst_offset, val_bf16, mask=mask_dim) - - -def dequant_paged_int8_to_bf16( - src_int8: torch.Tensor, # int8 [num_pages, page_size, head_dim] - src_scale: torch.Tensor, # fp16 [num_pages, page_size, 1] - page_indices: torch.Tensor, # int32 [num_accessed_pages] - page_size: int, - head_dim: int, - out: torch.Tensor = None, # optional pre-allocated bf16 [>=num_accessed_pages, page_size, head_dim] -) -> torch.Tensor: - """ - Dequantize only the accessed pages from int8 cache to a compact bf16 buffer. - - If `out` is provided, writes into it (must have room for num_accessed_pages). - Otherwise allocates a new buffer. - - Returns: - bf16 tensor of shape [num_accessed_pages, page_size, head_dim] - """ - num_accessed_pages = page_indices.shape[0] - if num_accessed_pages == 0: - if out is not None: - return out[:0] - return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src_int8.device) - - if out is not None: - dst_bf16 = out[:num_accessed_pages] - else: - dst_bf16 = torch.empty( - (num_accessed_pages, page_size, head_dim), - dtype=torch.bfloat16, - device=src_int8.device, - ) - - BLOCK_DIM = triton.next_power_of_2(head_dim) - - grid = (num_accessed_pages, page_size) - _dequant_pages_kernel[grid]( - src_int8, - src_scale, - dst_bf16, - page_indices, - NUM_PAGES=num_accessed_pages, - PAGE_SIZE=page_size, - HEAD_DIM=head_dim, - BLOCK_DIM=BLOCK_DIM, - ) - - return dst_bf16 - - -@triton.jit -def _dequant_pages_inplace_kernel( - src_int8, # int8 paged buffer flat - src_scale, # scale buffer flat (one scale per token slot) - dst_bf16, # bf16 destination buffer (same page layout as src) - page_indices, # int32 [num_pages] — which global pages to dequant - NUM_PAGES: tl.constexpr, - PAGE_SIZE: tl.constexpr, - HEAD_DIM: tl.constexpr, - BLOCK_DIM: tl.constexpr, -): - """Dequantize selected int8 pages to bf16, writing to the SAME page positions in dst.""" - page_idx = tl.program_id(0) # index into page_indices - token_idx = tl.program_id(1) # token within page [0, PAGE_SIZE) - - if page_idx >= NUM_PAGES: - return - - global_page_id = tl.load(page_indices + page_idx) - dims = tl.arange(0, BLOCK_DIM) - mask_dim = dims < HEAD_DIM - - # Source and destination use the SAME offset (in-place layout) - offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims - val_int8 = tl.load(src_int8 + offset, mask=mask_dim, other=0).to(tl.float32) - - scale_offset = global_page_id * PAGE_SIZE + token_idx - scale = tl.load(src_scale + scale_offset).to(tl.float32) - - val_bf16 = (val_int8 * scale).to(tl.bfloat16) - - # Write to the SAME page position in dst (not compacted) - tl.store(dst_bf16 + offset, val_bf16, mask=mask_dim) - - -def dequant_paged_int8_to_bf16_inplace( - src_int8: torch.Tensor, # int8 paged cache (flat) - src_scale: torch.Tensor, # fp16 scale buffer (flat) - dst_bf16: torch.Tensor, # bf16 destination (same shape as src_int8) - page_indices: torch.Tensor, # int32 [num_pages] — which pages to dequant - page_size: int, - head_dim: int, -) -> None: - """ - Dequantize selected pages from int8 cache to bf16 IN-PLACE. - - Unlike dequant_paged_int8_to_bf16 (which compacts into a dense buffer), - this writes to the SAME page positions in dst_bf16, preserving the paged layout. - Used to populate the bf16 working buffer for forward_cache (centroid computation). - """ - num_pages = page_indices.shape[0] - if num_pages == 0: - return - - BLOCK_DIM = triton.next_power_of_2(head_dim) - - grid = (num_pages, page_size) - _dequant_pages_inplace_kernel[grid]( - src_int8, - src_scale, - dst_bf16, - page_indices, - NUM_PAGES=num_pages, - PAGE_SIZE=page_size, - HEAD_DIM=head_dim, - BLOCK_DIM=BLOCK_DIM, - ) diff --git a/vortex_torch/cache/triton_kernels/set_kv.py b/vortex_torch/cache/triton_kernels/set_kv.py index 6b289df3..58468cc0 100644 --- a/vortex_torch/cache/triton_kernels/set_kv.py +++ b/vortex_torch/cache/triton_kernels/set_kv.py @@ -241,3 +241,498 @@ def set_kv_buffer_fp8_launcher( v_scale=v_scale, ) + +# --------------------------------------------------------------------------- +# Dequantization kernels (read direction: quantized paged cache → bf16) +# --------------------------------------------------------------------------- + +@triton.jit +def _dequant_pages_kernel( + src, # quantized paged buffer flat + src_scale, # per-token scale buffer flat (int8 only) + dst, # bf16 destination buffer flat + page_indices, # int32 page indices to dequant + NUM_PAGES, + PAGE_SIZE: tl.constexpr, + HEAD_DIM: tl.constexpr, + BLOCK_DIM: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # float: per-tensor scale (fp8 only) + COMPACT: tl.constexpr, # True: compact dst; False: in-place dst +): + """Unified dequant kernel for selected pages → bf16. + + QUANT_TYPE==1: load int8, multiply by per-token scale from src_scale. + QUANT_TYPE==2: load uint8, bitcast to float8e4nv, multiply by tensor_scale. + QUANT_TYPE==3: load uint8, bitcast to float8e5, multiply by tensor_scale. + COMPACT==True: dst offset uses page_idx (compact buffer). + COMPACT==False: dst offset uses global_page_id (in-place). + """ + page_idx = tl.program_id(0) + token_idx = tl.program_id(1) + + if page_idx >= NUM_PAGES: + return + + global_page_id = tl.load(page_indices + page_idx) + dims = tl.arange(0, BLOCK_DIM) + mask_dim = dims < HEAD_DIM + + src_offset = (global_page_id * PAGE_SIZE + token_idx) * HEAD_DIM + dims + scale_offset = global_page_id * PAGE_SIZE + token_idx + + if QUANT_TYPE == 1: + val = tl.load(src + src_offset, mask=mask_dim, other=0).to(tl.float32) + scale = tl.load(src_scale + scale_offset).to(tl.float32) + result = (val * scale).to(tl.bfloat16) + elif QUANT_TYPE == 2: + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + else: # QUANT_TYPE == 3 + raw = tl.load(src + src_offset, mask=mask_dim, other=0) + val = raw.to(tl.float8e5, bitcast=True).to(tl.float32) + result = (val * tensor_scale).to(tl.bfloat16) + + if COMPACT: + dst_offset = (page_idx * PAGE_SIZE + token_idx) * HEAD_DIM + dims + else: + dst_offset = src_offset # same position as source + + tl.store(dst + dst_offset, result, mask=mask_dim) + + +def dequant_pages_to_bf16( + src: torch.Tensor, + src_scale: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, + out: torch.Tensor = None, +) -> torch.Tensor: + """Dequant selected pages to compact bf16 buffer. + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). + out: optional pre-allocated bf16 buffer. + """ + num_accessed_pages = page_indices.shape[0] + if num_accessed_pages == 0: + if out is not None: + return out[:0] + return torch.empty((0, page_size, head_dim), dtype=torch.bfloat16, device=src.device) + + if out is not None: + dst = out[:num_accessed_pages] + else: + dst = torch.empty( + (num_accessed_pages, page_size, head_dim), + dtype=torch.bfloat16, + device=src.device, + ) + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_accessed_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_accessed_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=True, + ) + + return dst + + +def dequant_pages_to_bf16_inplace( + src: torch.Tensor, + src_scale: torch.Tensor, + dst: torch.Tensor, + page_indices: torch.Tensor, + page_size: int, + head_dim: int, + quant_type: int = 1, + tensor_scale: float = 1.0, +) -> None: + """Dequant selected pages in-place (same page positions in dst). + + Args: + quant_type: 1=int8 (per-token scale), 2=fp8 e4m3, 3=fp8 e5m2. + tensor_scale: per-tensor scale (fp8 only, ignored for int8). + """ + num_pages = page_indices.shape[0] + if num_pages == 0: + return + + BLOCK_DIM = triton.next_power_of_2(head_dim) + + grid = (num_pages, page_size) + _dequant_pages_kernel[grid]( + src, src_scale, dst, page_indices, + NUM_PAGES=num_pages, + PAGE_SIZE=page_size, + HEAD_DIM=head_dim, + BLOCK_DIM=BLOCK_DIM, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + COMPACT=False, + ) + + +# --------------------------------------------------------------------------- +# Paged decode attention (unified quant_type-parameterized) +# --------------------------------------------------------------------------- + +_MIN_BLOCK_KV = 32 + + +@triton.jit +def _tanh(x): + return 2 * tl.sigmoid(2 * x) - 1 + + +@triton.jit +def _fwd_kernel_paged_decode_stage1( + Q, + K_Buffer, + V_Buffer, + K_Scale_Buffer, + V_Scale_Buffer, + sm_scale, + kv_indptr, + kv_indices, + last_page_len, + Att_Out, + Att_Lse, + num_kv_splits, + stride_qbs, + stride_qh, + stride_buf_kbs, + stride_buf_vbs, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + kv_group_num: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_N: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, + QUANT_TYPE: tl.constexpr, # 0: bf16, 1: int8, 2: e4m3, 3: e5m2 + tensor_scale, # per-tensor scale for fp8 +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv + + cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch) + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + e_max = -float("inf") + e_sum = 0.0 + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + if split_kv_end > split_kv_start: + q = tl.load(Q + off_q, mask=mask_d, other=0.0).to(tl.float32) + + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + mask_n = offs_n < split_kv_end + + page_indices_in_seq = offs_n // PAGE_SIZE + in_page_offsets = offs_n % PAGE_SIZE + page_ids = tl.load( + kv_indices + cur_batch_kv_start_idx + page_indices_in_seq, + mask=mask_n, other=0, + ) + kv_loc = page_ids * PAGE_SIZE + in_page_offsets + + # Load K with quant-type-dependent dequantization + offs_buf_k = kv_loc[:, None] * stride_buf_kbs + offs_d[None, :] + if QUANT_TYPE == 0: + k = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + k_int8 = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ).to(tl.float32) + k_scale = tl.load( + K_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + k = k_int8 * k_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + K_Buffer + offs_buf_k, + mask=mask_n[:, None] & mask_d[None, :], other=0, + ) + k = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * _tanh(qk / logit_cap) + + qk = tl.where(mask_n, qk, float("-inf")) + + # Load V with quant-type-dependent dequantization + offs_buf_v = kv_loc[:, None] * stride_buf_vbs + offs_dv[None, :] + if QUANT_TYPE == 0: + v = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + elif QUANT_TYPE == 1: + v_int8 = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ).to(tl.float32) + v_scale = tl.load( + V_Scale_Buffer + kv_loc, mask=mask_n, other=1.0, + ).to(tl.float32) + v = v_int8 * v_scale[:, None] + elif QUANT_TYPE == 2: + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e4nv, bitcast=True).to(tl.float32) * tensor_scale + else: # QUANT_TYPE == 3 + raw = tl.load( + V_Buffer + offs_buf_v, + mask=mask_n[:, None] & mask_dv[None, :], other=0, + ) + v = raw.to(tl.float8e5, bitcast=True).to(tl.float32) * tensor_scale + + # Online softmax accumulation + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv + ) + + tl.store(Att_Out + offs_mid_o, acc / e_sum, mask=mask_dv) + + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + ) // Lv + + tl.store(Att_Lse + offs_mid_o_1, e_max + tl.log(e_sum)) + + +@triton.jit +def _fwd_kernel_paged_decode_stage2( + Mid_O, + Mid_O_1, + O, + kv_indptr, + last_page_len, + num_kv_splits, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + MAX_KV_SPLITS: tl.constexpr, + MIN_BLOCK_KV: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, + PAGE_SIZE: tl.constexpr, +): + """Stage 2: Reduce split outputs via log-sum-exp merge.""" + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + cur_batch_num_pages = tl.load(kv_indptr + cur_batch + 1) - tl.load(kv_indptr + cur_batch) + cur_last_page_len = tl.load(last_page_len + cur_batch) + cur_batch_seq_len = (cur_batch_num_pages - 1) * PAGE_SIZE + cur_last_page_len + kv_splits = tl.load(num_kv_splits + cur_batch) + + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = (cur_batch * stride_mid_ob + cur_head * stride_mid_oh) // Lv + kv_len_per_split = ( + tl.cdiv(tl.cdiv(cur_batch_seq_len, kv_splits), MIN_BLOCK_KV) * MIN_BLOCK_KV + ) + + for split_kv_id in range(0, MAX_KV_SPLITS): + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) + + if split_kv_end > split_kv_start: + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O_1 + offs_logic + split_kv_id * stride_mid_os // Lv) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def paged_decode( + q: torch.Tensor, + k_buffer: torch.Tensor, + v_buffer: torch.Tensor, + o: torch.Tensor, + kv_indptr: torch.Tensor, + kv_indices: torch.Tensor, + last_page_len: torch.Tensor, + num_kv_splits: torch.Tensor, + max_kv_splits: int, + sm_scale: float, + page_size: int, + quant_type: int = 0, + k_scale_buffer: torch.Tensor = None, + v_scale_buffer: torch.Tensor = None, + tensor_scale: float = 1.0, + logit_cap: float = 0.0, + att_out: torch.Tensor = None, + att_lse: torch.Tensor = None, +): + """Unified paged decode attention. + + Args: + quant_type: Controls K/V loading: + 0: bf16 (k_scale_buffer/v_scale_buffer unused) + 1: int8 with per-token scales (k_scale_buffer/v_scale_buffer required) + 2: fp8 e4m3 with per-tensor scale (tensor_scale required) + 3: fp8 e5m2 with per-tensor scale (tensor_scale required) + """ + batch = q.shape[0] + head_num = q.shape[1] + Lk = q.shape[2] + Lv = Lk + + BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) + BLOCK_N = 128 + MAX_KV_SPLITS = max_kv_splits + + kv_group_num = head_num + num_warps = 4 + + if att_out is None: + att_out = torch.empty( + (batch, head_num, MAX_KV_SPLITS, Lv), + dtype=torch.float32, device=q.device, + ) + else: + att_out = att_out[:batch] + if att_lse is None: + att_lse = torch.empty( + (batch, head_num, MAX_KV_SPLITS), + dtype=torch.float32, device=q.device, + ) + else: + att_lse = att_lse[:batch] + + stride_buf_kbs = k_buffer.shape[-1] + stride_buf_vbs = v_buffer.shape[-1] + + # Use dummy tensors for scale buffers when not needed + _k_scale = k_scale_buffer if k_scale_buffer is not None else k_buffer + _v_scale = v_scale_buffer if v_scale_buffer is not None else v_buffer + + grid_stage1 = (batch, head_num, MAX_KV_SPLITS) + _fwd_kernel_paged_decode_stage1[grid_stage1]( + q, k_buffer, v_buffer, + _k_scale, _v_scale, + sm_scale, kv_indptr, kv_indices, last_page_len, + att_out, att_lse, num_kv_splits, + q.stride(0), q.stride(1), + stride_buf_kbs, stride_buf_vbs, + att_out.stride(0), att_out.stride(1), att_out.stride(2), + kv_group_num=kv_group_num, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, + BLOCK_N=BLOCK_N, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + logit_cap=logit_cap, + num_warps=num_warps, + num_stages=2, + Lk=Lk, Lv=Lv, + PAGE_SIZE=page_size, + QUANT_TYPE=quant_type, + tensor_scale=tensor_scale, + ) + + grid_stage2 = (batch, head_num) + _fwd_kernel_paged_decode_stage2[grid_stage2]( + att_out, att_lse, o, + kv_indptr, last_page_len, num_kv_splits, + att_out.stride(0), att_out.stride(1), att_out.stride(2), + o.stride(0), o.stride(1), + MAX_KV_SPLITS=MAX_KV_SPLITS, + MIN_BLOCK_KV=_MIN_BLOCK_KV, + BLOCK_DV=BLOCK_DV, + Lv=Lv, + PAGE_SIZE=page_size, + num_warps=4, + num_stages=2, + ) + From 19c7fcc573019b87a17f7629eb3602ce8dfd1752 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 31 Mar 2026 05:00:36 +0000 Subject: [PATCH 13/24] =?UTF-8?q?Add=20TopK=20benchmarking=20suite=20and?= =?UTF-8?q?=20related=20scripts=20-=20Introduced=20a=20comprehensive=20ben?= =?UTF-8?q?chmarking=20suite=20for=20TopK=20kernel=20variants,=20measuring?= =?UTF-8?q?=20kernel-level=20latency.=20-=20Added=20scripts=20for=20offlin?= =?UTF-8?q?e=20calibration=20of=20TopK=20mapping=20modes,=20including:#=20?= =?UTF-8?q?0:=20None=20=20=20=20=20=20=20=20=20=20=20=E2=80=94=20original?= =?UTF-8?q?=20fp16=20bit-pattern=20bucketing=20#=201:=20LUT=20CDF=20=20=20?= =?UTF-8?q?=20=20=20=20=20=E2=80=94=20LUT-based=20CDF=20equalization=20(ca?= =?UTF-8?q?librated)=20#=202:=20Quantile=20=20=20=20=20=20=20=E2=80=94=20p?= =?UTF-8?q?iecewise-linear=20quantile=20mapping=20(calibrated)=20#=203:=20?= =?UTF-8?q?Power=20=20=20=20=20=20=20=20=20=20=E2=80=94=20y=20=3D=20sign(x?= =?UTF-8?q?)=20*=20|x|^p=20#=204:=20Log=20=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=20=E2=80=94=20y=20=3D=20sign(x)=20*=20log(|x|=20+=201)=20#=205?= =?UTF-8?q?:=20Index=20Cache=20=20=20=20=E2=80=94=20reuse=20previous=20lay?= =?UTF-8?q?er's=20indices=20#=206:=20Asinh=20=20=20=20=20=20=20=20=20=20?= =?UTF-8?q?=E2=80=94=20y=20=3D=20asinh(beta=20*=20x)=20#=207:=20Log1p=20?= =?UTF-8?q?=20=20=20=20=20=20=20=20=20=E2=80=94=20y=20=3D=20sign(x)=20*=20?= =?UTF-8?q?log1p(alpha=20*=20|x|)=20#=208:=20Trunc8=20=20=20=20=20=20=20?= =?UTF-8?q?=20=20=E2=80=94=20bf16=20upper-8-bit=20bucketing=20-=20=20Addin?= =?UTF-8?q?g=20various=20remap=20functions=20for=20the=20bucket=20sort=20i?= =?UTF-8?q?n=20sglang=20topk=20kernel,=20with=20evaluation=20and=20visuali?= =?UTF-8?q?zation=20scripts.=20-=20Implemented=20analysis=20tools=20for=20?= =?UTF-8?q?TopK=20distribution=20profiling.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 3 + benchmarks/README.md | 89 + benchmarks/__init__.py | 0 benchmarks/analyze_topk_distribution.py | 479 +++++ benchmarks/autotune_topk_mapping.py | 378 ++++ benchmarks/bench_topk.py | 587 +++++ benchmarks/calibrate_topk.py | 153 ++ benchmarks/greedy_layer_search.py | 117 + benchmarks/profile_topk_distribution.py | 132 ++ csrc/register.cc | 19 +- csrc/register.h | 19 +- csrc/topk.cu | 2 +- csrc/topk_mapping.cuh | 148 ++ csrc/topk_sglang.cu | 1905 ++++++++++------- examples/README.md | 399 ++++ examples/run_distribution_analysis.sh | 141 ++ examples/run_distribution_analysis_new.sh | 150 ++ examples/run_topk_benchmark.sh | 294 +++ examples/verify_algo.py | 78 +- examples/verify_algo.sh | 5 +- examples/verify_algo_quant.sh | 18 +- examples/verify_algo_topk_mapping.sh | 175 ++ .../verify_algo_topk_mapping_indexcache.sh | 45 + examples/verify_algo_topk_mapping_new.sh | 128 ++ third_party/sglang | 2 +- vortex_torch/indexer/context.py | 26 +- vortex_torch/indexer/output_func.py | 60 +- 27 files changed, 4698 insertions(+), 854 deletions(-) create mode 100644 benchmarks/README.md create mode 100644 benchmarks/__init__.py create mode 100644 benchmarks/analyze_topk_distribution.py create mode 100644 benchmarks/autotune_topk_mapping.py create mode 100644 benchmarks/bench_topk.py create mode 100644 benchmarks/calibrate_topk.py create mode 100644 benchmarks/greedy_layer_search.py create mode 100644 benchmarks/profile_topk_distribution.py create mode 100644 csrc/topk_mapping.cuh create mode 100644 examples/README.md create mode 100755 examples/run_distribution_analysis.sh create mode 100755 examples/run_distribution_analysis_new.sh create mode 100755 examples/run_topk_benchmark.sh create mode 100644 examples/verify_algo_topk_mapping.sh create mode 100644 examples/verify_algo_topk_mapping_indexcache.sh create mode 100644 examples/verify_algo_topk_mapping_new.sh diff --git a/.gitignore b/.gitignore index 931c8cae..6a904c74 100644 --- a/.gitignore +++ b/.gitignore @@ -236,3 +236,6 @@ compile_commands.json # Rust lib Cargo.lock + +/examples/results +*.npy \ No newline at end of file diff --git a/benchmarks/README.md b/benchmarks/README.md new file mode 100644 index 00000000..e390344d --- /dev/null +++ b/benchmarks/README.md @@ -0,0 +1,89 @@ +# TopK Kernel Benchmarking Suite + +Standalone benchmarking for Vortex's three topk kernel variants, measuring kernel-level latency isolated from the full SGLang inference pipeline. + +## Kernel Variants + +| Kernel | Description | +|--------|-------------| +| `naive` | CUB radix sort (bf16 only) | +| `sglang_m0` | Two-stage hierarchical radix sort, no mapping | +| `sglang_m1` | + LUT mapping (requires `--lut-path`) | +| `sglang_m2` | + Quantile mapping (requires `--quantiles-path`) | +| `sglang_m3` | + Power mapping (configurable via `--mapping-power`) | +| `sglang_m4` | + Log mapping | + +## Quick Start + +```bash +# Activate environment +source /scr/dataset/yuke/xinrui/uv_env/vortex/bin/activate + +# Quick single-config test +python benchmarking/bench_topk.py \ + --batch-sizes 8 \ + --seq-lens 4096 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --repeat 200 + +# Sweep with histogram analysis +python benchmarking/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 64 \ + --num-kv-heads 2 \ + --repeat 100 \ + --histogram + +# Full sweep with JSON output +python benchmarking/bench_topk.py \ + --output-json benchmarking/results.json \ + --histogram +``` + +## CLI Options + +| Argument | Default | Description | +|----------|---------|-------------| +| `--batch-sizes` | 1 4 8 16 32 64 | Batch sizes to sweep | +| `--seq-lens` | 1024 2048 4096 8192 | Sequence lengths to sweep | +| `--topk-vals` | 16 30 64 | TopK values to sweep | +| `--num-kv-heads` | 2 4 8 | KV head counts to sweep | +| `--page-size` | 16 | Tokens per page | +| `--reserved-bos` | 1 | Reserved BOS pages | +| `--reserved-eos` | 2 | Reserved EOS pages | +| `--score-dtype` | bfloat16 | Score tensor dtype (bfloat16 or float32) | +| `--distributions` | normal lognormal uniform | Score distributions to test | +| `--warmup` | 10 | Warmup iterations | +| `--repeat` | 100 | Timed iterations | +| `--mapping-power` | 0.5 | Power parameter for mode=3 | +| `--lut-path` | None | Path to .npy uint8[256] LUT for mode=1 | +| `--quantiles-path` | None | Path to .npy float32[256] quantiles for mode=2 | +| `--output-json` | None | Save results to JSON file | +| `--filter-kernels` | None | Only run specific kernels (e.g., `naive sglang_m0`) | +| `--histogram` | False | Collect bin distribution statistics | + +## Histogram Analysis + +When `--histogram` is passed, each config additionally runs `topk_profile_histogram` and reports: + +- **max/mean ratio**: Peak bin count divided by average (lower = more uniform) +- **Gini coefficient**: Inequality measure of bin distribution (0 = perfectly uniform) +- **nonzero_bins**: How many of the 256 bins received any values + +This shows whether mapping modes improve bin uniformity for a given score distribution. + +## Output Format + +``` +TopK Kernel Benchmark Results +GPU: NVIDIA H100 80GB HBM3 | SM count: 132 + +bs=8 | seq=4096 | topk=30 | heads=2 | pages/seg=256 | dist=normal + naive : 0.0420ms (median) +/- 0.0030ms [min=0.0390, max=0.0510] + sglang mode=0 : 0.0310ms (median) +/- 0.0020ms [min=0.0290, max=0.0380] + sglang mode=3 : 0.0330ms (median) +/- 0.0020ms [min=0.0300, max=0.0400] + sglang mode=4 : 0.0320ms (median) +/- 0.0020ms [min=0.0300, max=0.0390] + histogram stats : max/mean=3.99 gini=0.568 nonzero_bins=70/256 +``` diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py new file mode 100644 index 00000000..7d944667 --- /dev/null +++ b/benchmarks/analyze_topk_distribution.py @@ -0,0 +1,479 @@ +""" +TopK distribution analysis and visualization. + +Loads profiling data from: + - profile_topk_distribution.py output (.npz): raw histograms, LUT tables + - bench_topk.py output (.json): benchmark results + per-mode histogram data + +Produces visualization plots for evaluating mapping mode effectiveness. + +Usage: + python scripts/analyze_topk_distribution.py \ + --bench-json bench_hitrate.json \ + --output-dir plots/ + + python scripts/analyze_topk_distribution.py \ + --profile-npz profile_output.npz \ + --bench-json bench_hitrate.json \ + --output-dir plots/ --max-segments 8 +""" + +import argparse +import json +import os +from typing import Optional + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import matplotlib.colors as mcolors +import numpy as np + +# Canonical mapping mode names — shared across all profiling/analysis tools +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 5: "Index Cache", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", +} + +MAPPING_MODE_FORMULAS = { + 0: "None (fp16 bucketing)", + 1: "LUT CDF (calibrated)", + 2: "Quantile (calibrated)", + 3: "Power: sign(x)*|x|^p", + 4: "Log: sign(x)*log(|x|+1)", + 5: "Index Cache", + 6: "Asinh: asinh(beta*x)", + 7: "Log1p: sign(x)*log1p(alpha*|x|)", + 8: "Trunc8: bf16 upper-8-bit bucketing", +} + + +def _mode_key_to_display(mode_key: str) -> str: + """Convert a mode key like 'mode_3' or 'mode_3_Power' to a display name.""" + # Handle new format: "mode_3_Power" + parts = mode_key.split("_", 2) + if len(parts) >= 3: + return parts[2] # e.g. "Power" + # Handle old format: "mode_3" + try: + mode_num = int(parts[1]) + return MAPPING_MODE_NAMES.get(mode_num, mode_key) + except (IndexError, ValueError): + return mode_key + + +def _mode_key_to_number(mode_key: str) -> int: + """Extract the mode number from a key like 'mode_3' or 'mode_3_Power'.""" + parts = mode_key.split("_") + try: + return int(parts[1]) + except (IndexError, ValueError): + return -1 + + +def compute_per_segment_stats(histograms: np.ndarray) -> dict: + """Compute per-row Gini coefficient and max/mean ratio. + + Args: + histograms: [num_segments, 256] array of bin counts + + Returns: + dict with 'gini' and 'max_mean' arrays of shape [num_segments] + """ + num_seg = histograms.shape[0] + ginis = np.zeros(num_seg) + max_means = np.zeros(num_seg) + + for i in range(num_seg): + row = histograms[i].astype(np.float64) + nonzero = row[row > 0] + if len(nonzero) == 0: + continue + + max_means[i] = nonzero.max() / nonzero.mean() + + # Gini coefficient + sorted_vals = np.sort(nonzero) + n = len(sorted_vals) + index = np.arange(1, n + 1, dtype=np.float64) + ginis[i] = (2.0 * (index * sorted_vals).sum() / (n * sorted_vals.sum()) - (n + 1) / n) + ginis[i] = max(0.0, ginis[i]) + + return {"gini": ginis, "max_mean": max_means} + + +def plot_bin_distribution(histograms: np.ndarray, output_dir: str, max_segments: int = 4): + """Plot 256-bin bar chart per segment (first N segments).""" + num_seg = min(histograms.shape[0], max_segments) + for i in range(num_seg): + fig, ax = plt.subplots(figsize=(12, 4)) + ax.bar(range(256), histograms[i], width=1.0, color="steelblue", edgecolor="none") + ax.set_xlabel("Bin") + ax.set_ylabel("Count") + ax.set_title(f"Segment {i}: 256-bin histogram") + ax.set_xlim(-1, 256) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"bin_dist_seg_{i}.png"), dpi=150) + plt.close(fig) + print(f" Saved {num_seg} bin distribution plots") + + +def plot_bin_heatmap(histograms: np.ndarray, output_dir: str): + """Heatmap: segments x bins, LogNorm colormap.""" + fig, ax = plt.subplots(figsize=(14, max(4, histograms.shape[0] * 0.15 + 1))) + # Add 1 to avoid log(0) + data = histograms.astype(np.float64) + 1 + im = ax.imshow( + data, + aspect="auto", + cmap="viridis", + norm=mcolors.LogNorm(vmin=1, vmax=data.max()), + interpolation="nearest", + ) + ax.set_xlabel("Bin") + ax.set_ylabel("Segment") + ax.set_title("Bin distribution heatmap (log scale)") + fig.colorbar(im, ax=ax, label="Count + 1") + fig.tight_layout() + fig.savefig(os.path.join(output_dir, "bin_heatmap.png"), dpi=150) + plt.close(fig) + print(" Saved bin_heatmap.png") + + +def plot_before_after_mapping( + raw_histograms: np.ndarray, + lut_table: np.ndarray, + output_dir: str, + max_segments: int = 4, +): + """Side-by-side: raw histogram vs. LUT-remapped histogram.""" + num_seg = min(raw_histograms.shape[0], max_segments) + for i in range(num_seg): + raw = raw_histograms[i] + # Remap: redistribute counts through LUT + remapped = np.zeros(256, dtype=np.float64) + for bin_idx in range(256): + new_bin = int(lut_table[bin_idx]) + remapped[new_bin] += raw[bin_idx] + + fig, axes = plt.subplots(1, 2, figsize=(16, 4), sharey=True) + axes[0].bar(range(256), raw, width=1.0, color="steelblue", edgecolor="none") + axes[0].set_title(f"Segment {i}: Raw (mode=0)") + axes[0].set_xlabel("Bin") + axes[0].set_ylabel("Count") + + axes[1].bar(range(256), remapped, width=1.0, color="darkorange", edgecolor="none") + axes[1].set_title(f"Segment {i}: After LUT remap") + axes[1].set_xlabel("Bin") + + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f"mapping_comparison_{i}.png"), dpi=150) + plt.close(fig) + print(f" Saved {num_seg} mapping comparison plots") + + +def plot_summary_table( + histograms: np.ndarray, + mode_stats_data: Optional[dict], + output_dir: str, +): + """Per-segment stats table: Gini, max/mean, resolution rate.""" + stats = compute_per_segment_stats(histograms) + num_seg = histograms.shape[0] + + col_labels = ["Segment", "Gini", "Max/Mean"] + cell_data = [] + for i in range(num_seg): + cell_data.append([str(i), f"{stats['gini'][i]:.3f}", f"{stats['max_mean'][i]:.2f}"]) + + fig, ax = plt.subplots(figsize=(6, max(2, num_seg * 0.4 + 1))) + ax.axis("off") + table = ax.table(cellText=cell_data, colLabels=col_labels, loc="center", cellLoc="center") + table.auto_set_font_size(False) + table.set_fontsize(9) + table.scale(1.0, 1.3) + ax.set_title("Per-segment distribution stats", fontsize=11, pad=10) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, "summary_table.png"), dpi=150, bbox_inches="tight") + plt.close(fig) + print(" Saved summary_table.png") + + +def plot_distribution_comparison(dist_histograms: dict, output_dir: str, suffix: str = "", title: str = ""): + """Overlay 256-bin distributions for different data sources (uniform, normal, real). + + Args: + dist_histograms: {"uniform": [256], "normal": [256], "real": [256], ...} + output_dir: output directory for the plot + suffix: optional suffix for output filename (e.g. "_m0") + title: optional custom title for the plot + """ + names = list(dist_histograms.keys()) + n = len(names) + if n == 0: + print(" No distribution histograms to compare") + return + + fig, axes = plt.subplots(1, n, figsize=(6 * n, 4), squeeze=False) + axes = axes[0] + + for idx, name in enumerate(names): + counts = np.array(dist_histograms[name], dtype=np.float64) + ax = axes[idx] + ax.bar(range(256), counts, width=1.0, color="steelblue", edgecolor="none") + ax.set_xlabel("Bucket") + ax.set_ylabel("Count") + ax.set_xlim(-1, 256) + ax.set_title(name) + + # Annotate with stats + nonzero = counts[counts > 0] + if len(nonzero) > 0: + mean_val = nonzero.mean() + max_val = nonzero.max() + max_mean = max_val / mean_val if mean_val > 0 else 0.0 + sorted_vals = np.sort(nonzero) + nn = len(sorted_vals) + index = np.arange(1, nn + 1, dtype=np.float64) + gini = max(0.0, 2.0 * (index * sorted_vals).sum() / (nn * sorted_vals.sum()) - (nn + 1) / nn) + nz_bins = int(len(nonzero)) + else: + max_mean = gini = 0.0 + nz_bins = 0 + + stats_text = f"gini={gini:.3f}\nmax/mean={max_mean:.2f}\nbins={nz_bins}/256" + ax.text(0.97, 0.95, stats_text, transform=ax.transAxes, + fontsize=8, verticalalignment="top", horizontalalignment="right", + bbox=dict(boxstyle="round,pad=0.3", facecolor="wheat", alpha=0.7)) + + fig.suptitle(title or "Bucket Distribution Comparison", fontsize=13) + fig.tight_layout() + fname = f"distribution_comparison{suffix}.png" + fig.savefig(os.path.join(output_dir, fname), dpi=150) + plt.close(fig) + print(f" Saved {fname}") + + +def save_bucket_table(dist_histograms: dict, output_dir: str, filename: str = "bucket_counts.csv"): + """Write a CSV table listing the count per bucket for each distribution. + + Columns: bucket, dist1, dist2, ... (256 rows, one per bucket). + """ + import csv + + names = list(dist_histograms.keys()) + if not names: + return + + path = os.path.join(output_dir, filename) + with open(path, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(["bucket"] + names) + for b in range(256): + row = [b] + [int(dist_histograms[n][b]) for n in names] + writer.writerow(row) + + # Also print a compact summary to stdout (top-20 hottest buckets per dist) + print(f" Saved {path}") + for name in names: + counts = np.array(dist_histograms[name], dtype=np.int64) + total = counts.sum() + top_idx = np.argsort(counts)[::-1][:20] + print(f" [{name}] total={total} top-20 hottest buckets:") + for rank, idx in enumerate(top_idx): + if counts[idx] == 0: + break + pct = counts[idx] / total * 100 if total > 0 else 0 + print(f" #{rank+1:2d} bucket {idx:3d}: {counts[idx]:>10d} ({pct:5.1f}%)") + + +def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str): + """Grouped bar chart comparing modes on gini and max/mean.""" + modes = sorted(mode_stats_data.keys()) + if not modes: + print(" No histogram data to plot mode comparison") + return + + mode_labels = [] + for m in modes: + label = _mode_key_to_display(m) + param = mode_stats_data[m].get("param") + if param: + label = f"{label} ({param})" + mode_labels.append(label) + ginis = [mode_stats_data[m]["gini"] for m in modes] + max_means = [mode_stats_data[m]["max_mean_ratio"] for m in modes] + + x = np.arange(len(modes)) + width = 0.3 + + fig, ax1 = plt.subplots(figsize=(10, 5)) + ax2 = ax1.twinx() + + bars1 = ax1.bar(x - width / 2, ginis, width, label="Gini", color="darkorange") + bars2 = ax2.bar(x + width / 2, max_means, width, label="Max/Mean", color="seagreen", alpha=0.7) + + ax1.set_xlabel("Mapping Mode") + ax1.set_ylabel("Gini") + ax2.set_ylabel("Max/Mean Ratio") + ax1.set_xticks(x) + ax1.set_xticklabels(mode_labels, rotation=15, ha="right") + ax1.set_ylim(0, 1.1) + ax1.set_title("Mapping Mode Comparison") + + # Combine legends + lines1, labels1 = ax1.get_legend_handles_labels() + lines2, labels2 = ax2.get_legend_handles_labels() + ax1.legend(lines1 + lines2, labels1 + labels2, loc="upper right") + + fig.tight_layout() + fig.savefig(os.path.join(output_dir, "mode_comparison.png"), dpi=150) + plt.close(fig) + print(" Saved mode_comparison.png") + + +def main(): + parser = argparse.ArgumentParser(description="Analyze TopK bucket sort distribution") + parser.add_argument("--profile-npz", type=str, default=None, + help="Path to .npz from profile_topk_distribution.py") + parser.add_argument("--bench-json", type=str, default=None, + help="Path to JSON from bench_topk.py") + parser.add_argument("--output-dir", type=str, default="plots", + help="Directory for output plots") + parser.add_argument("--max-segments", type=int, default=4, + help="Max segments for per-segment plots") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to .npy raw_histograms from calibrate_topk.py (real-data bucket counts)") + args = parser.parse_args() + + if args.profile_npz is None and args.bench_json is None and args.real_histograms is None: + parser.error("At least one of --profile-npz, --bench-json, or --real-histograms is required") + + os.makedirs(args.output_dir, exist_ok=True) + print(f"Output directory: {args.output_dir}") + + raw_histograms = None + lut_table = None + mode_stats_data = None + + # Load profile data + if args.profile_npz: + print(f"\nLoading profile data from {args.profile_npz}") + data = np.load(args.profile_npz, allow_pickle=True) + if "raw_histograms" in data: + raw_histograms = data["raw_histograms"] + print(f" raw_histograms: {raw_histograms.shape}") + if "aggregate_lut" in data: + lut_table = data["aggregate_lut"] + print(f" aggregate_lut: {lut_table.shape}") + elif "lut_tables" in data: + # Use first LUT if aggregate not available + lut_table = data["lut_tables"] + if lut_table.ndim > 1: + lut_table = lut_table[0] + print(f" lut_table: {lut_table.shape}") + + # Load bench data + dist_histograms = {} # {distribution_name: [256] counts} for comparison plot + mode_histograms = {} # {mode_key: {dist_name: [256]}} for per-mode plots + + if args.bench_json: + print(f"\nLoading benchmark data from {args.bench_json}") + with open(args.bench_json) as f: + bench_data = json.load(f) + + if bench_data and isinstance(bench_data, list): + # Use first config entry for histogram mode visualization + entry = bench_data[0] + if "histograms" in entry: + mode_stats_data = entry["histograms"] + print(f" Histogram modes: {list(mode_stats_data.keys())}") + + # Extract raw_counts per distribution from bench entries + for entry in bench_data: + dist_name = entry.get("distribution", "unknown") + hist_data = entry.get("histogram", {}) + if "raw_counts" in hist_data and dist_name not in dist_histograms: + dist_histograms[dist_name] = hist_data["raw_counts"] + print(f" Loaded histogram for distribution: {dist_name}") + + # Extract per-mode histograms from histograms data + mode_histograms = {} # {mode_key: {dist_name: [256]}} + for entry in bench_data: + dist_name = entry.get("distribution", "unknown") + histograms_data = entry.get("histograms", {}) + for mode_key, mode_data in histograms_data.items(): + if isinstance(mode_data, dict) and "raw_counts" in mode_data: + if mode_key not in mode_histograms: + mode_histograms[mode_key] = {} + if dist_name not in mode_histograms[mode_key]: + mode_histograms[mode_key][dist_name] = mode_data["raw_counts"] + if mode_histograms: + print(f" Loaded per-mode histograms for: {sorted(mode_histograms.keys())}") + + # Load real-data histograms from .npy (calibrate_topk.py output) + real_counts = None + if args.real_histograms: + print(f"\nLoading real-data histograms from {args.real_histograms}") + real_hists = np.load(args.real_histograms) # [num_samples, 256] + real_counts = real_hists.sum(axis=0).tolist() # aggregate across samples + dist_histograms["real"] = real_counts + print(f" real_histograms shape: {real_hists.shape}, aggregated to [256]") + + # Generate plots + if raw_histograms is not None: + print("\nGenerating histogram plots...") + plot_bin_distribution(raw_histograms, args.output_dir, args.max_segments) + plot_bin_heatmap(raw_histograms, args.output_dir) + plot_summary_table(raw_histograms, mode_stats_data, args.output_dir) + + if lut_table is not None: + print("\nGenerating before/after mapping comparison...") + plot_before_after_mapping(raw_histograms, lut_table, args.output_dir, args.max_segments) + + if mode_stats_data is not None: + print("\nGenerating mode comparison plot...") + plot_mapping_mode_comparison(mode_stats_data, args.output_dir) + + if dist_histograms: + print("\nGenerating distribution comparison plot (raw/unmapped)...") + plot_distribution_comparison(dist_histograms, args.output_dir) + print("\nSaving bucket count table (raw/unmapped)...") + save_bucket_table(dist_histograms, args.output_dir) + + # Per-mode distribution plots and tables + if mode_histograms: + print("\nGenerating per-mode distribution plots and tables...") + for mode_key in sorted(mode_histograms): + mname = _mode_key_to_display(mode_key) + mode_num = _mode_key_to_number(mode_key) + mformula = MAPPING_MODE_FORMULAS.get(mode_num, mname) + # Include hyperparameter value in title if available + param_str = "" + if mode_stats_data and mode_key in mode_stats_data: + param = mode_stats_data[mode_key].get("param") + if param: + param_str = f" [{param}]" + mode_suffix = mname.lower().replace(" ", "_") + plot_distribution_comparison( + mode_histograms[mode_key], args.output_dir, + suffix=f"_{mode_suffix}", + title=f"Bucket Distribution — {mname}{param_str} ({mformula})", + ) + save_bucket_table( + mode_histograms[mode_key], args.output_dir, + filename=f"bucket_counts_{mode_suffix}.csv", + ) + + print(f"\nDone. All outputs saved to {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py new file mode 100644 index 00000000..9b37e32f --- /dev/null +++ b/benchmarks/autotune_topk_mapping.py @@ -0,0 +1,378 @@ +""" +Auto-tuner for TopK mapping hyperparameters. + +Sweeps all (mode, hyperparameter) combinations using the topk_hit_rate +kernel and ranks by Stage 1 resolution rate. + +Supports real-data score distributions via --real-histograms: loads the +raw_histograms.npy from calibration and synthesizes score tensors that +match the real bin distribution (by reversing the convert_to_uint8 mapping). + +Sweep grid: + - Mode 3 (power): p in [0.1, 0.25, 0.75, 0.9] + - Mode 6 (asinh): beta in [0.1, 0.5, 1, 2, 4] + - Mode 7 (log1p): alpha in [0.1, 0.5, 0.75, 1, 2, 4, 8] + - Baselines: mode 0 (none), mode 4 (log) + +Usage: + python benchmarks/autotune_topk_mapping.py --topk-val 30 --real-histograms calibration/raw_histograms.npy + python benchmarks/autotune_topk_mapping.py --topk-val 30 --output-json results.json +""" + +import argparse +import json +import math +from typing import List + +import numpy as np +import torch + +from bench_topk import make_topk_inputs, compute_histogram_stats +from vortex_torch_C import topk_profile_histogram + + + +SWEEP_GRID = { + # (mode, param_name, param_values) + 3: ("power_exp", [0.1, 0.25, 0.75, 0.9]), + 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), + 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), +} +BASELINES = { + 0: ("none", 0.5), + 4: ("log", 0.5), +} +MODE_NAMES = { + 0: "none", + 3: "power", + 4: "log", + 6: "asinh", + 7: "log1p", +} + + +def _key_to_fp16(key: int) -> np.float16: + """Invert the convert_to_uint8 sign-flip for a single 16-bit key.""" + if key >= 0x8000: + bits = key & 0x7FFF + else: + bits = (~key) & 0xFFFF + return np.array([bits], dtype=np.uint16).view(np.float16)[0] + + +def build_bin_range_table(): + """Build per-bin (lo, hi) fp16 value tables by iterating all 65536 fp16 bit patterns. + + For each fp16 value, compute its bin via convert_to_uint8 logic, then track + the min/max fp16 value that lands in each bin. + + Returns: + (bin_lo, bin_hi): two [256] float32 arrays — the min and max fp16 values per bin. + """ + # Generate all 65536 fp16 bit patterns + all_bits = np.arange(65536, dtype=np.uint16) + all_fp16 = all_bits.view(np.float16) + + # Compute convert_to_uint8 for each: key = sign-flip, bin = key >> 8 + keys = np.where( + (all_bits & 0x8000).astype(bool), + (~all_bits).astype(np.uint16), + all_bits | np.uint16(0x8000), + ) + bins = (keys >> 8).astype(np.uint8) + + # Convert to float32 for min/max (fp16 has NaNs/Infs, filter them) + all_f32 = all_fp16.astype(np.float32) + valid = np.isfinite(all_f32) + + bin_lo = np.full(256, np.inf, dtype=np.float32) + bin_hi = np.full(256, -np.inf, dtype=np.float32) + + for b in range(256): + mask = (bins == b) & valid + if mask.any(): + vals = all_f32[mask] + bin_lo[b] = vals.min() + bin_hi[b] = vals.max() + + # For any bin with no valid fp16 values, fall back to midpoint + empty = bin_lo > bin_hi + for b in np.where(empty)[0]: + mid_key = (int(b) << 8) | 0x80 + val = float(_key_to_fp16(mid_key)) + bin_lo[b] = val + bin_hi[b] = val + + return bin_lo, bin_hi + + +def scores_from_histogram( + histogram: np.ndarray, + total_pages: int, + device: str = "cuda", +) -> torch.Tensor: + """Generate score tensor matching a real bin distribution. + + For each sampled bin, generates a uniform random fp16 value within the + bin's actual value range (not just the midpoint), so that mapped transforms + see diverse input values. + + Args: + histogram: [256] aggregated bin counts from calibration + total_pages: number of score entries to generate + device: torch device + + Returns: + scores: [total_pages, 1, 1] bfloat16 tensor + """ + bin_lo, bin_hi = build_bin_range_table() + + # Normalize histogram to probability distribution + counts = histogram.astype(np.float64) + total = counts.sum() + if total == 0: + return torch.zeros(total_pages, 1, 1, dtype=torch.bfloat16, device=device) + probs = counts / total + + # Sample bin indices according to the real distribution + bin_indices = np.random.choice(256, size=total_pages, p=probs) + + # Uniform random within each bin's fp16 range + lo = bin_lo[bin_indices] + hi = bin_hi[bin_indices] + rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) + scores_f32 = lo + rand * (hi - lo) + + # Convert float32 -> bfloat16 tensor + scores = torch.from_numpy(scores_f32).to(torch.bfloat16) + return scores.reshape(total_pages, 1, 1).to(device) + + +def make_real_inputs( + batch_size: int, + num_kv_heads: int, + seq_len: int, + page_size: int, + topk_val: int, + reserved_bos: int, + reserved_eos: int, + histogram: np.ndarray, + device: str = "cuda", +) -> dict: + """Build CSR-formatted inputs with scores matching a real histogram.""" + eff_batch_size = batch_size * num_kv_heads + num_pages_per_seg = math.ceil(seq_len / page_size) + total_dense_pages = eff_batch_size * num_pages_per_seg + sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg) + total_sparse_pages = eff_batch_size * sparse_per_seg + + dense_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device=device, + ) + sparse_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device=device, + ) + dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) + sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) + + x = scores_from_histogram(histogram, total_dense_pages, device=device) + + return { + "x": x, + "dense_kv_indptr": dense_kv_indptr, + "sparse_kv_indptr": sparse_kv_indptr, + "dense_kv_indices": dense_kv_indices, + "sparse_kv_indices": sparse_kv_indices, + "eff_batch_size": eff_batch_size, + "num_pages_per_seg": num_pages_per_seg, + "sparse_per_seg": sparse_per_seg, + } + + +def run_sweep(args) -> List[dict]: + """Run all (mode, hyperparam) combos and return ranked results.""" + results = [] + + # Load real histogram if provided + real_histogram = None + if args.real_histograms: + raw = np.load(args.real_histograms) # [num_segments, 256] + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw # aggregate to [256] + + distributions = args.distributions + if real_histogram is not None: + distributions = ["real"] + + for dist in distributions: + if dist == "real": + inputs = make_real_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + histogram=real_histogram, + ) + else: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist, + ) + + eff_bs = inputs["eff_batch_size"] + + def evaluate(mode: int, power: float, label: str): + hists = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], + inputs["dense_kv_indptr"], + hists, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + None, # lut + None, # quantiles + ) + torch.cuda.synchronize() + stats = compute_histogram_stats(hists) + return { + "label": label, + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param": power, + "distribution": dist, + "gini": stats["gini"], + "max_mean_ratio": stats["max_mean_ratio"], + "num_nonzero_bins": stats["num_nonzero_bins"], + } + + # Baselines + for mode, (name, default_power) in BASELINES.items(): + r = evaluate(mode, default_power, f"m{mode}_{name}") + results.append(r) + + # Parametric sweep + for mode, (param_name, values) in SWEEP_GRID.items(): + mname = MODE_NAMES[mode] + for val in values: + label = f"m{mode}_{mname}_{param_name}={val}" + r = evaluate(mode, val, label) + results.append(r) + + return results + + +def print_table(results: List[dict]): + """Print ranked results as a formatted table.""" + # Sort by Gini ascending (lower = more uniform = better) + ranked = sorted(results, key=lambda r: r["gini"]) + + header = ( + f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} " + f"{'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" + ) + print("\n" + "=" * len(header)) + print("TopK Mapping Auto-Tune Results (ranked by Gini, lower=better)") + print("=" * len(header)) + print(header) + print("-" * len(header)) + + for i, r in enumerate(ranked): + print( + f"{i+1:4d} {r['label']:<35s} {r['distribution']:<12s} " + f"{r['gini']:6.3f} " + f"{r['max_mean_ratio']:8.2f} {r['num_nonzero_bins']:6d}" + ) + + print("=" * len(header)) + if ranked: + best = ranked[0] + print( + f"\nBest overall: {best['label']} (dist={best['distribution']}) " + f"— gini={best['gini']:.3f}, max/mean={best['max_mean_ratio']:.2f}" + ) + + # Per-mode best summary (lowest gini per mode) + mode_best = {} + for r in results: + m = r["mode"] + if m not in mode_best or r["gini"] < mode_best[m]["gini"]: + mode_best[m] = r + + if mode_best: + print("\nBest per mode:") + for m in sorted(mode_best.keys()): + r = mode_best[m] + mname = MODE_NAMES.get(m, f"m{m}") + if m in SWEEP_GRID: + param_name = SWEEP_GRID[m][0] + param_str = f"{param_name}={r['param']}" + else: + param_str = "(baseline)" + print( + f" Mode {m:d} ({mname:>5s}): {param_str:<20s} " + f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}" + ) + + +def main(): + parser = argparse.ArgumentParser( + description="Auto-tune TopK mapping hyperparameters" + ) + parser.add_argument("--batch-size", type=int, default=4) + parser.add_argument("--seq-len", type=int, default=4096) + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--num-kv-heads", type=int, default=2) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--reserved-bos", type=int, default=1) + parser.add_argument("--reserved-eos", type=int, default=2) + parser.add_argument( + "--distributions", nargs="+", + default=["normal"], + help="Score distributions for synthetic data (ignored when --real-histograms is set)", + ) + parser.add_argument( + "--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibration. When set, auto-tunes on " + "real score distribution instead of synthetic data.", + ) + parser.add_argument( + "--output-json", type=str, default=None, + help="Save results to JSON file", + ) + args = parser.parse_args() + + source = f"real ({args.real_histograms})" if args.real_histograms else f"synthetic ({args.distributions})" + print(f"Auto-tuning TopK mapping hyperparameters") + print(f" batch_size={args.batch_size}, seq_len={args.seq_len}, " + f"topk_val={args.topk_val}, num_kv_heads={args.num_kv_heads}") + print(f" score source: {source}") + n_parametric = sum(len(v) for _, v in SWEEP_GRID.values()) + n_dists = 1 if args.real_histograms else len(args.distributions) + print(f" sweep: {n_parametric} parametric + {len(BASELINES)} baselines " + f"= {n_parametric + len(BASELINES)} combos x {n_dists} dists") + + results = run_sweep(args) + print_table(results) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py new file mode 100644 index 00000000..ca039f2f --- /dev/null +++ b/benchmarks/bench_topk.py @@ -0,0 +1,587 @@ +""" +TopK kernel benchmarking suite. + +Measures kernel-level latency for the three topk variants (naive/CUB, +sglang with mapping modes) across configurable grid of batch sizes, +sequence lengths, topk values, and KV head counts. + +Usage: + python benchmarking/bench_topk.py --batch-sizes 4 8 --seq-lens 2048 4096 --topk-vals 30 --num-kv-heads 2 --repeat 50 +""" + +import argparse +import json +import math +import statistics +from typing import Dict, List, Optional + +import numpy as np +import torch + +from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram + +# Canonical mapping mode names — used in logs, tables, and plots +MAPPING_MODE_NAMES = { + 0: "None", + 1: "LUT CDF", + 2: "Quantile", + 3: "Power", + 4: "Log", + 5: "Index Cache", + 6: "Asinh", + 7: "Log1p", + 8: "Trunc8", +} + +MAPPING_MODE_FORMULAS = { + 0: "None (fp16 bucketing)", + 1: "LUT CDF (calibrated)", + 2: "Quantile (calibrated)", + 3: "Power: sign(x)*|x|^p", + 4: "Log: sign(x)*log(|x|+1)", + 5: "Index Cache", + 6: "Asinh: asinh(beta*x)", + 7: "Log1p: sign(x)*log1p(alpha*|x|)", + 8: "Trunc8: bf16 upper-8-bit bucketing", +} + + +def make_topk_inputs( + batch_size: int, + num_kv_heads: int, + seq_len: int, + page_size: int, + topk_val: int, + reserved_bos: int, + reserved_eos: int, + score_dtype: torch.dtype, + distribution: str = "normal", + device: str = "cuda", +) -> dict: + """Synthesize realistic CSR-formatted paged attention inputs.""" + eff_batch_size = batch_size * num_kv_heads + num_pages_per_seg = math.ceil(seq_len / page_size) + total_dense_pages = eff_batch_size * num_pages_per_seg + sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg) + total_sparse_pages = eff_batch_size * sparse_per_seg + + dense_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device=device, + ) + sparse_kv_indptr = torch.arange( + 0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device=device, + ) + dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) + sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) + + # Generate scores with the requested distribution + if distribution == "normal": + x = torch.randn(total_dense_pages, 1, 1, device=device) + elif distribution == "lognormal": + x = torch.randn(total_dense_pages, 1, 1, device=device).exp() + elif distribution == "uniform": + x = torch.rand(total_dense_pages, 1, 1, device=device) + elif distribution == "bucket_uniform": + # Uniform across all 256 fp16 radix buckets. + # Random uint16 bit patterns → interpret as fp16. + # Bucket = upper 8 bits of sign-flipped fp16, so random bits → uniform buckets. + raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device) + # Exclude fp16 NaN/Inf (exponent=31, i.e. |bits| >= 0x7C00) + abs_bits = raw_bits & 0x7FFF + raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 # → ±0 + # Reinterpret int16 bits as fp16, then widen to float32 + x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1) + else: + raise ValueError(f"Unknown distribution: {distribution}") + + x = x.to(score_dtype) + + return { + "x": x, + "dense_kv_indptr": dense_kv_indptr, + "sparse_kv_indptr": sparse_kv_indptr, + "dense_kv_indices": dense_kv_indices, + "sparse_kv_indices": sparse_kv_indices, + "eff_batch_size": eff_batch_size, + "num_pages_per_seg": num_pages_per_seg, + "sparse_per_seg": sparse_per_seg, + } + + +def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict: + """Time a kernel with CUDA events, return latency stats in ms.""" + for _ in range(warmup): + kernel_fn(*args) + torch.cuda.synchronize() + + start_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + end_events = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + for i in range(repeat): + start_events[i].record() + kernel_fn(*args) + end_events[i].record() + torch.cuda.synchronize() + + times = [s.elapsed_time(e) for s, e in zip(start_events, end_events)] + return { + "mean_ms": statistics.mean(times), + "median_ms": statistics.median(times), + "std_ms": statistics.stdev(times) if len(times) > 1 else 0.0, + "min_ms": min(times), + "max_ms": max(times), + } + + +def compute_histogram_stats(histograms: torch.Tensor) -> dict: + """Compute bin distribution statistics from histogram tensor [B, 256].""" + h = histograms.float() + # Aggregate across batch dimension + h_sum = h.sum(dim=0) # [256] + nonzero_bins = h_sum[h_sum > 0] + if len(nonzero_bins) == 0: + return { + "max_mean_ratio": 0.0, "std": 0.0, "gini": 0.0, + "num_nonzero_bins": 0, "entropy": 0.0, "effective_bins": 0.0, + } + + mean_val = nonzero_bins.mean().item() + max_val = nonzero_bins.max().item() + std_val = nonzero_bins.std().item() if len(nonzero_bins) > 1 else 0.0 + + # Gini coefficient + sorted_bins = nonzero_bins.sort().values + n = len(sorted_bins) + index = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) + gini = (2.0 * (index * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() + + # Shannon entropy (base-2) + p = nonzero_bins / nonzero_bins.sum() + entropy = -(p * p.log2()).sum().item() + # Effective number of bins: 2^entropy + effective_bins = 2 ** entropy + + return { + "max_mean_ratio": max_val / mean_val if mean_val > 0 else 0.0, + "std": std_val, + "gini": max(0.0, gini), + "num_nonzero_bins": int(len(nonzero_bins)), + "entropy": entropy, + "effective_bins": effective_bins, + } + + +NUM_HISTOGRAM_BINS = 256 + + +def _histogram_target_pages(pages_per_seg: int, min_samples_per_bin: int = 512) -> int: + """Compute adaptive page count for statistically reliable histograms. + + With 256 radix bins, each bin needs enough samples for stable gini / + max-mean statistics. Returns a total page count rounded up to a full + segment boundary so every segment contributes equally. + """ + min_pages = min_samples_per_bin * NUM_HISTOGRAM_BINS + return math.ceil(min_pages / pages_per_seg) * pages_per_seg + + +def _load_autotune_powers(path: str) -> Dict[int, float]: + """Extract best per-mode power from autotune JSON. + + Ranks by res_rate_mean (higher=better) if present, else by gini (lower=better). + Returns {mode: best_power}, e.g. {3: 0.25, 6: 1.0, 7: 2.0}. + """ + with open(path) as f: + data = json.load(f) + + has_res_rate = any("res_rate_mean" in r for r in data) + + best: Dict[int, dict] = {} + for r in data: + m = r.get("mode") + if m not in (3, 6, 7): + continue + if has_res_rate: + score = r.get("res_rate_mean", 0.0) + is_better = m not in best or score > best[m]["_score"] + else: + score = r.get("gini", 1.0) + is_better = m not in best or score < best[m]["_score"] + if is_better: + best[m] = {"param": r["param"], "_score": score} + + return {m: v["param"] for m, v in best.items()} + + +def _resolve_mode_power(args, mode: int) -> float: + """Return the power/beta/alpha for a parametric mapping mode. + + Priority: per-mode CLI flag > autotune JSON > global --mapping-power. + """ + per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7} + if mode in per_mode_flag and per_mode_flag[mode] is not None: + return per_mode_flag[mode] + if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: + return args._autotune_powers[mode] + return args.mapping_power + + +def run_benchmark(args) -> List[dict]: + """Run the full benchmark sweep and return results.""" + # Load autotune results if provided + if args.autotune_json: + args._autotune_powers = _load_autotune_powers(args.autotune_json) + print(f"Loaded autotune best powers: {args._autotune_powers}") + else: + args._autotune_powers = {} + + dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} + score_dtype = dtype_map[args.score_dtype] + + # Load real histogram if provided + real_histogram = None + _scores_from_histogram = None + if args.real_histograms: + from autotune_topk_mapping import scores_from_histogram + _scores_from_histogram = scores_from_histogram + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + + # Extend distributions with "real" if calibration data is provided + distributions = list(args.distributions) + if real_histogram is not None: + distributions.append("real") + args.distributions = distributions + + # Print GPU info + gpu_name = torch.cuda.get_device_name(0) + gpu_props = torch.cuda.get_device_properties(0) + print(f"TopK Kernel Benchmark Results") + print(f"GPU: {gpu_name} | SM count: {gpu_props.multi_processor_count}") + print(f"Score dtype: {args.score_dtype} | Warmup: {args.warmup} | Repeat: {args.repeat}") + print("=" * 90) + + # Load optional LUT / quantiles + mapping_lut = None + mapping_quantiles = None + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + mapping_lut = torch.from_numpy(lut_np).cuda() + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + mapping_quantiles = torch.from_numpy(q_np).cuda() + + # Build kernel list + all_kernels = { + "naive": "naive", + "sglang_m0": "sglang_m0", + "sglang_m3": "sglang_m3", + "sglang_m4": "sglang_m4", + "sglang_m6": "sglang_m6", + "sglang_m7": "sglang_m7", + "sglang_m8": "sglang_m8", + } + if mapping_lut is not None: + all_kernels["sglang_m1"] = "sglang_m1" + if mapping_quantiles is not None: + all_kernels["sglang_m2"] = "sglang_m2" + + if args.filter_kernels: + all_kernels = {k: v for k, v in all_kernels.items() if k in args.filter_kernels} + + # Naive kernel only supports bf16 + if score_dtype != torch.bfloat16 and "naive" in all_kernels: + print(f"Note: naive kernel only supports bfloat16, skipping for {args.score_dtype}") + del all_kernels["naive"] + + all_results = [] + + for bs in args.batch_sizes: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for num_kv_heads in args.num_kv_heads: + for dist in args.distributions: + if dist == "real" and real_histogram is not None: + inputs = make_topk_inputs( + batch_size=bs, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=score_dtype, + distribution="normal", + ) + # Replace scores with real-distribution scores + total_dense = inputs["eff_batch_size"] * inputs["num_pages_per_seg"] + inputs["x"] = _scores_from_histogram( + real_histogram, total_dense, device="cuda", + ) + else: + inputs = make_topk_inputs( + batch_size=bs, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=score_dtype, + distribution=dist, + ) + + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + + config_str = ( + f"bs={bs} | seq={seq_len} | topk={topk_val} | " + f"heads={num_kv_heads} | pages/seg={pages_per_seg} | dist={dist}" + ) + print(f"\n{config_str}") + + config_results = { + "batch_size": bs, + "seq_len": seq_len, + "topk_val": topk_val, + "num_kv_heads": num_kv_heads, + "distribution": dist, + "eff_batch_size": eff_bs, + "pages_per_seg": pages_per_seg, + "kernels": {}, + } + + for kernel_name in all_kernels: + # Reset sparse indices each run + inputs["sparse_kv_indices"].zero_() + + if kernel_name == "naive": + # topk_output: (x, dense_indptr, dense_indices, sparse_indptr, sparse_indices, ...) + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indptr"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) + else: + # Parse mapping mode from kernel name + mode = int(kernel_name.split("_m")[1]) + extra_kwargs = {} + if mode == 1: + extra_kwargs["mapping_lut"] = mapping_lut + elif mode == 2: + extra_kwargs["mapping_quantiles"] = mapping_quantiles + + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + + # topk_output_sglang: (x, dense_indptr, sparse_indptr, dense_indices, sparse_indices, ...) + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + extra_kwargs.get("mapping_lut", None), + extra_kwargs.get("mapping_quantiles", None), + ) + result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) + + if kernel_name == "naive": + label = "naive" + else: + m = int(kernel_name.split("_m")[1]) + mname = MAPPING_MODE_NAMES.get(m, f'm{m}') + if m in (3, 6, 7): + pname = {3: "p", 6: "beta", 7: "alpha"}[m] + label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)})" + else: + label = f"sglang {mname}" + print( + f" {label:<30s}: {result['median_ms']:.4f}ms (median) " + f"\u00b1 {result['std_ms']:.4f}ms " + f"[min={result['min_ms']:.4f}, max={result['max_ms']:.4f}]" + ) + config_results["kernels"][kernel_name] = result + + # Histogram analysis + if args.histogram: + # Build a separate (potentially larger) dataset for histogram profiling + target_pages = (args.histogram_pages + if args.histogram_pages is not None + else _histogram_target_pages(pages_per_seg)) + current_pages = eff_bs * pages_per_seg + if target_pages > current_pages: + hist_bs = math.ceil(target_pages / (num_kv_heads * pages_per_seg)) + if dist == "real" and real_histogram is not None: + hist_inputs = make_topk_inputs( + batch_size=hist_bs, num_kv_heads=num_kv_heads, + seq_len=seq_len, page_size=args.page_size, + topk_val=topk_val, reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, score_dtype=score_dtype, + distribution="normal", + ) + total_hist_dense = hist_inputs["eff_batch_size"] * hist_inputs["num_pages_per_seg"] + hist_inputs["x"] = _scores_from_histogram(real_histogram, total_hist_dense, device="cuda") + else: + hist_inputs = make_topk_inputs( + batch_size=hist_bs, num_kv_heads=num_kv_heads, + seq_len=seq_len, page_size=args.page_size, + topk_val=topk_val, reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, score_dtype=score_dtype, + distribution=dist, + ) + hist_eff_bs = hist_inputs["eff_batch_size"] + actual_pages = hist_eff_bs * pages_per_seg + print( + f" histogram dataset : {actual_pages} pages " + f"(upscaled from {current_pages} for statistical reliability)" + ) + else: + hist_inputs = inputs + hist_eff_bs = eff_bs + actual_pages = current_pages + print(f" histogram dataset : {actual_pages} pages") + + # Raw unmapped histogram + histograms = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + histograms, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + ) + hstats = compute_histogram_stats(histograms) + hstats["raw_counts"] = histograms.sum(dim=0).tolist() # [256] ints + config_results["histogram"] = hstats + print( + f" histogram stats : max/mean={hstats['max_mean_ratio']:.2f} " + f"gini={hstats['gini']:.3f} " + f"nonzero_bins={hstats['num_nonzero_bins']}/256" + ) + + # Per-mode histogram analysis + modes_to_test = [0, 3, 4, 6, 7, 8] + if mapping_lut is not None: + modes_to_test.append(1) + if mapping_quantiles is not None: + modes_to_test.append(2) + modes_to_test.sort() + + histograms_results = {} + print(f" --- histogram by mapping mode ---") + for mode in modes_to_test: + mode_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + + extra_lut = mapping_lut if mode == 1 else None + extra_q = mapping_quantiles if mode == 2 else None + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + mode_hists, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + extra_lut, + extra_q, + ) + torch.cuda.synchronize() + + mode_stats = compute_histogram_stats(mode_hists) + mode_stats["raw_counts"] = mode_hists.sum(dim=0).tolist() + mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") + mformula = MAPPING_MODE_FORMULAS.get(mode, mname) + mode_stats["name"] = mname + mode_stats["formula"] = mformula + if mode in (3, 6, 7): + pname = {3: "p", 6: "beta", 7: "alpha"}[mode] + mode_stats["param"] = f"{pname}={power}" + histograms_results[f"mode_{mode}_{mname}"] = mode_stats + if mode in (3, 6, 7): + pname = {3: "p", 6: "beta", 7: "alpha"}[mode] + display_name = f"{mname} ({pname}={power})" + else: + display_name = mname + print( + f" {display_name:<22s} (mode {mode}): " + f"gini={mode_stats['gini']:.3f} " + f"max/mean={mode_stats['max_mean_ratio']:.2f} " + f"nonzero_bins={mode_stats['num_nonzero_bins']}/256 " + f"eff_bins={mode_stats['effective_bins']:.1f} " + f"entropy={mode_stats['entropy']:.2f}" + ) + config_results["histograms"] = histograms_results + + all_results.append(config_results) + + return all_results + + +def main(): + parser = argparse.ArgumentParser(description="TopK kernel benchmark suite") + parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 4, 8, 16, 32, 64]) + parser.add_argument("--seq-lens", nargs="+", type=int, default=[1024, 2048, 4096, 8192]) + parser.add_argument("--topk-vals", nargs="+", type=int, default=[16, 30, 64]) + parser.add_argument("--num-kv-heads", nargs="+", type=int, default=[2, 4, 8]) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--reserved-bos", type=int, default=1) + parser.add_argument("--reserved-eos", type=int, default=2) + parser.add_argument("--score-dtype", choices=["bfloat16", "float32"], default="bfloat16") + parser.add_argument("--distributions", nargs="+", default=["normal", "lognormal", "uniform"]) + parser.add_argument("--warmup", type=int, default=10) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--mapping-power", type=float, default=0.5, + help="Global fallback power parameter for parametric modes (default: 0.5)") + parser.add_argument("--mapping-power-3", type=float, default=None, + help="Power exponent p for mode 3 (overrides --mapping-power)") + parser.add_argument("--mapping-power-6", type=float, default=None, + help="Beta for mode 6 asinh (overrides --mapping-power)") + parser.add_argument("--mapping-power-7", type=float, default=None, + help="Alpha for mode 7 log1p (overrides --mapping-power)") + parser.add_argument("--autotune-json", type=str, default=None, + help="Path to autotune_results.json — extracts best per-mode hyperparameters " + "(overrides --mapping-power for modes 3/6/7)") + parser.add_argument("--lut-path", type=str, default=None, help="Path to .npy uint8[256] LUT for mode=1") + parser.add_argument("--quantiles-path", type=str, default=None, help="Path to .npy float32[256] for mode=2") + parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file") + parser.add_argument("--filter-kernels", nargs="+", default=None, + help="Only run specific kernels: naive, sglang_m0, sglang_m3, sglang_m4") + parser.add_argument("--histogram", action="store_true", help="Collect and report bin distribution statistics") + parser.add_argument("--histogram-pages", type=int, default=None, + help="Total pages for histogram profiling. Default: adaptive " + "(512 samples/bin × 256 bins, rounded to segment boundary). " + "Only used when --histogram is set.") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to .npy raw_histograms from calibration (adds 'real' distribution)") + + args = parser.parse_args() + results = run_benchmark(args) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py new file mode 100644 index 00000000..4c861161 --- /dev/null +++ b/benchmarks/calibrate_topk.py @@ -0,0 +1,153 @@ +#!/usr/bin/env python3 +""" +Offline calibration for TopK mapping modes 1 (LUT CDF) and 2 (quantile). + +Runs the model on real data with hit-rate profiling enabled, collects score +histograms from the topk_sglang kernel, and generates: + - lut.npy : uint8[256] CDF-equalized LUT for mapping mode 1 + - quantiles.npy: float32[256] quantile breakpoints for mapping mode 2 + +Usage: + python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration_output/ +""" + +import argparse +import json +import os +import sys + +import numpy as np + +# Add project root to path so we can import from benchmarks/ +sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) + +from benchmarks.profile_topk_distribution import ( + compute_lut_from_histogram, + generate_tables_from_histograms, +) + + +def main(): + parser = argparse.ArgumentParser( + description="Offline calibration for TopK mapping modes 1 & 2" + ) + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--mem", type=float, default=0.7) + parser.add_argument("--kv-cache-dtype", type=str, default="auto") + parser.add_argument("--topk-type", type=str, default="sglang") + parser.add_argument("--num-prompts", type=int, default=16, + help="Number of calibration prompts to use (default: 16)") + parser.add_argument("--output-dir", type=str, default="calibration_output/") + parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + args = parser.parse_args() + + # Lazy imports to avoid slow startup when just checking --help + import sglang as sgl + import torch + import vortex_torch + + os.makedirs(args.output_dir, exist_ok=True) + + print(f"[calibrate] Launching engine with hit-rate profiling enabled...") + llm = sgl.Engine( + model_path=args.model_name, + disable_cuda_graph=True, + page_size=args.page_size, + vortex_topk_val=args.topk_val, + disable_overlap_schedule=True, + attention_backend="flashinfer", + enable_vortex_sparsity=True, + vortex_page_reserved_bos=1, + vortex_page_reserved_eos=2, + vortex_layers_skip=list(range(1)), + vortex_module_name=args.vortex_module_name, + vortex_max_seq_lens=12288, + mem_fraction_static=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + vortex_topk_type=args.topk_type, + vortex_topk_mapping_mode=0, # Use mode 0 during calibration + vortex_topk_histogram=True, # Enable histogram collection + ) + + # Clear any residual histograms in the worker process + llm.clear_topk_histograms() + + # Load calibration prompts + prompts_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "..", "examples", "amc23.jsonl" + ) + with open(prompts_path, "r", encoding="utf-8") as f: + all_requests = [json.loads(line) for line in f] + + # Use up to num_prompts + requests = all_requests[:args.num_prompts] + prompts = [req["prompt"] for req in requests] + + print(f"[calibrate] Running {len(prompts)} calibration prompts...") + sampling_params = { + "temperature": 0.6, + "top_p": 0.95, + "top_k": 20, + "max_new_tokens": 8192, + } + llm.generate(prompts, sampling_params) + + # Collect histograms via RPC from worker process + histograms = llm.get_topk_histograms() + print(f"[calibrate] Collected {len(histograms)} histogram batches") + + if len(histograms) == 0: + print("[calibrate] ERROR: No histograms collected. " + "Ensure topk_type='sglang' and vortex_topk_histogram=True.", + file=sys.stderr) + llm.shutdown() + sys.exit(1) + + # Stack all histograms: each is [eff_bs, 256], concatenate along batch dim + all_hists = torch.cat(histograms, dim=0).numpy() # [total_samples, 256] + print(f"[calibrate] Total histogram samples: {all_hists.shape[0]}") + + # --- Generate LUT (mode 1) --- + # Aggregate histogram across all samples + avg_histogram = all_hists.mean(axis=0) + lut = compute_lut_from_histogram(avg_histogram) + lut_path = os.path.join(args.output_dir, "lut.npy") + np.save(lut_path, lut) + print(f"[calibrate] Saved LUT to {lut_path} (shape={lut.shape}, dtype={lut.dtype})") + + # --- Generate quantiles (mode 2) --- + # Use bin centers as proxy scores weighted by histogram counts + bin_centers = np.arange(256, dtype=np.float32) + # Expand histogram counts into a weighted score distribution + total_counts = avg_histogram.astype(np.float64) + total = total_counts.sum() + if total > 0: + cdf = np.cumsum(total_counts) / total + # Invert CDF to get quantile breakpoints in [0, 255] space + percentiles = np.linspace(0, 1, 256) + quantiles = np.interp(percentiles, cdf, bin_centers).astype(np.float32) + else: + quantiles = bin_centers.copy() + + quantiles_path = os.path.join(args.output_dir, "quantiles.npy") + np.save(quantiles_path, quantiles) + print(f"[calibrate] Saved quantiles to {quantiles_path} (shape={quantiles.shape}, dtype={quantiles.dtype})") + + # Save raw histograms for debugging + raw_path = os.path.join(args.output_dir, "raw_histograms.npy") + np.save(raw_path, all_hists) + print(f"[calibrate] Saved raw histograms to {raw_path} (shape={all_hists.shape})") + + # Cleanup + llm.clear_topk_histograms() + llm.shutdown() + print(f"[calibrate] Done. Output files in {args.output_dir}/") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/greedy_layer_search.py b/benchmarks/greedy_layer_search.py new file mode 100644 index 00000000..118ac454 --- /dev/null +++ b/benchmarks/greedy_layer_search.py @@ -0,0 +1,117 @@ +"""Greedy forward-selection of layers whose indexer can be skipped (index cache). + +Usage (from repo root): + cd examples && python ../benchmarks/greedy_layer_search.py \ + --model-name Qwen/Qwen3-1.7B --topk-val 30 --threshold 0.95 \ + --trials 1 --num-layers 28 --mem 0.7 + +The script prints progress to stderr and outputs the final selected layer list +(as a Python list literal) on the **last line of stdout** so callers can parse it. +""" + +import argparse +import os +import sys + +# Add examples/ to path so we can import verify_algos +_examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "examples") +sys.path.insert(0, _examples_dir) + +from verify_algo import verify_algos # noqa: E402 + + +def _evaluate(shared_layers, args): + """Run verify_algos with the given shared layers and return pass@trials accuracy.""" + summary = verify_algos( + trials=args.trials, + topk_val=args.topk_val, + page_size=args.page_size, + vortex_module_name=args.vortex_module_name, + model_name=args.model_name, + sparse_attention=True, + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, + topk_mapping_mode=0, + topk_mapping_power=args.topk_mapping_power, + index_cache_shared_layers=sorted(shared_layers) if shared_layers else None, + disable_cuda_graph=True, + ) + acc_key = f"pass@{args.trials}" + return summary[acc_key] + + +def greedy_search(args): + # Ensure we're in examples/ so amc23.jsonl relative path works + os.chdir(_examples_dir) + + candidates = list(range(1, args.num_layers)) + + # Baseline: no shared layers + print("Evaluating baseline (no shared layers)...", file=sys.stderr) + baseline_acc = _evaluate([], args) + print(f"Baseline accuracy: {baseline_acc:.4f}", file=sys.stderr) + + threshold = args.threshold + shared_set = [] + + while candidates: + best_layer = None + best_acc = -1.0 + + for layer in candidates: + trial_set = shared_set + [layer] + print(f" Trying shared_set={sorted(trial_set)} ...", file=sys.stderr, end=" ") + acc = _evaluate(trial_set, args) + print(f"acc={acc:.4f}", file=sys.stderr) + + if acc > best_acc: + best_acc = acc + best_layer = layer + + if best_acc >= threshold * baseline_acc: + shared_set.append(best_layer) + candidates.remove(best_layer) + print( + f"Added layer {best_layer} (acc={best_acc:.4f} >= " + f"{threshold * baseline_acc:.4f}). Current set: {sorted(shared_set)}", + file=sys.stderr, + ) + else: + print( + f"Stopping: best candidate layer {best_layer} acc={best_acc:.4f} < " + f"{threshold * baseline_acc:.4f}", + file=sys.stderr, + ) + break + + result = sorted(shared_set) + print(f"Final shared layers: {result}", file=sys.stderr) + # Last stdout line: parseable Python list + print(result) + return result + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Greedy forward-selection of index-cache shared layers." + ) + parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") + parser.add_argument("--topk-val", type=int, default=30) + parser.add_argument("--page-size", type=int, default=16) + parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + parser.add_argument("--mem", type=float, default=0.8) + parser.add_argument("--kv-cache-dtype", type=str, default="auto") + parser.add_argument("--topk-type", type=str, default="naive") + parser.add_argument("--topk-mapping-power", type=float, default=0.5) + parser.add_argument("--threshold", type=float, default=0.95, + help="Minimum accuracy ratio vs baseline to keep adding layers (default: 0.95).") + parser.add_argument("--trials", type=int, default=1) + parser.add_argument("--num-layers", type=int, default=28, + help="Total number of model layers (default: 28 for Qwen3-1.7B).") + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + greedy_search(args) diff --git a/benchmarks/profile_topk_distribution.py b/benchmarks/profile_topk_distribution.py new file mode 100644 index 00000000..bea911b0 --- /dev/null +++ b/benchmarks/profile_topk_distribution.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" +Profile TopK bin distribution and generate mapping tables. + +This script collects Stage 1 (8-bit coarse histogram) distributions from +the topk_sglang kernel and generates LUT/quantile mapping tables that +can be used to equalize the bin distribution for improved sorting efficiency. + +Usage: + python scripts/profile_topk_distribution.py \ + --model-name Qwen/Qwen3-1.7B \ + --output mapping_tables.npz \ + --num-prompts 32 \ + --mem 0.7 + +Output (.npz): + lut_tables: [num_collected, 256] uint8 - CDF-equalized LUT per sample + quantile_tables: [num_collected, 256] float32 - quantile breakpoints per sample + raw_histograms: [num_collected, 256] int32 - raw bin histograms +""" + +import argparse +import numpy as np +import torch + + +def compute_lut_from_histogram(histogram: np.ndarray) -> np.ndarray: + """Compute CDF-equalized LUT from a 256-bin histogram. + + Args: + histogram: [256] int array of bin counts + + Returns: + lut: [256] uint8 array where lut[i] = floor(CDF(i) * 255) + """ + cdf = np.cumsum(histogram).astype(np.float64) + total = cdf[-1] + if total == 0: + return np.arange(256, dtype=np.uint8) + cdf_normalized = cdf / total + lut = np.floor(cdf_normalized * 255).astype(np.uint8) + return lut + + +def compute_quantiles_from_scores(scores: np.ndarray, num_quantiles: int = 256) -> np.ndarray: + """Compute quantile breakpoints from raw float scores. + + Args: + scores: 1D array of float scores + num_quantiles: number of quantile bins (default 256) + + Returns: + quantiles: [num_quantiles] float32 array of sorted breakpoints + """ + if len(scores) == 0: + return np.zeros(num_quantiles, dtype=np.float32) + percentiles = np.linspace(0, 100, num_quantiles) + quantiles = np.percentile(scores, percentiles).astype(np.float32) + return quantiles + + +def generate_tables_from_histograms(histograms: np.ndarray) -> dict: + """Generate LUT and quantile tables from collected histograms. + + Args: + histograms: [N, 256] int32 array of bin histograms + + Returns: + dict with 'lut_tables' and 'aggregate_lut' + """ + N = histograms.shape[0] + lut_tables = np.zeros((N, 256), dtype=np.uint8) + + for i in range(N): + lut_tables[i] = compute_lut_from_histogram(histograms[i]) + + # Aggregate: average histogram across all samples + avg_histogram = histograms.mean(axis=0) + aggregate_lut = compute_lut_from_histogram(avg_histogram) + + return { + 'lut_tables': lut_tables, + 'aggregate_lut': aggregate_lut, + } + + +def main(): + parser = argparse.ArgumentParser( + description="Profile TopK bin distribution and generate mapping tables") + parser.add_argument("--output", type=str, default="mapping_tables.npz", + help="Output .npz file path") + parser.add_argument("--histograms-input", type=str, default=None, + help="Load pre-collected histograms from .npy file instead of running inference") + parser.add_argument("--scores-input", type=str, default=None, + help="Load pre-collected raw scores from .npy for quantile computation") + args = parser.parse_args() + + results = {} + + if args.histograms_input: + print(f"Loading histograms from {args.histograms_input}") + histograms = np.load(args.histograms_input) + if histograms.ndim == 1: + histograms = histograms.reshape(1, -1) + results['raw_histograms'] = histograms + + tables = generate_tables_from_histograms(histograms) + results.update(tables) + + if args.scores_input: + print(f"Loading scores from {args.scores_input}") + scores = np.load(args.scores_input) + quantiles = compute_quantiles_from_scores(scores.flatten()) + results['quantile_table'] = quantiles + + if not results: + print("No input provided. Use --histograms-input or --scores-input.") + print("\nTo collect histograms, use the topk_profile_histogram() function from vortex_torch_C:") + print(" from vortex_torch_C import topk_profile_histogram") + print(" histograms = torch.zeros(eff_batch_size, 256, dtype=torch.int32, device='cuda')") + print(" topk_profile_histogram(scores, dense_kv_indptr, histograms, eff_batch_size, bos, eos)") + print(" np.save('histograms.npy', histograms.cpu().numpy())") + return + + np.savez(args.output, **results) + print(f"Saved mapping tables to {args.output}") + for key, val in results.items(): + print(f" {key}: shape={val.shape}, dtype={val.dtype}") + + +if __name__ == "__main__": + main() diff --git a/csrc/register.cc b/csrc/register.cc index 532fcdfa..00674743 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -8,7 +8,24 @@ PYBIND11_MODULE(vortex_torch_C, m){ m.def("Chunkwise_NH2HN_Transpose", &Chunkwise_NH2HN_Transpose); m.def("Chunkwise_HN2NH_Transpose", &Chunkwise_HN2NH_Transpose); m.def("topk_output", &topk_output); - m.def("topk_output_sglang", &topk_output_sglang); + m.def("topk_output_sglang", &topk_output_sglang, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_profile_histogram", &topk_profile_histogram, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("histograms"), py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index b81168bb..d4f2d8b4 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -95,7 +95,24 @@ const int64_t eff_batch_size, const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_seq_lengths +const int64_t max_seq_lengths, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + +void topk_profile_histogram( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& histograms, +const int64_t eff_batch_size, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt ); void sglang_plan_decode_fa3( diff --git a/csrc/topk.cu b/csrc/topk.cu index 3aa49b98..70d2000a 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -117,8 +117,8 @@ const int page_reserved_eos) void topk_output( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, const at::Tensor& dense_kv_indices, +const at::Tensor& sparse_kv_indptr, at::Tensor& sparse_kv_indices, const int64_t eff_batch_size, const int64_t topk_val, diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh new file mode 100644 index 00000000..e3fe3a73 --- /dev/null +++ b/csrc/topk_mapping.cuh @@ -0,0 +1,148 @@ +#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort distribution mapping strategies +// +// These transforms remap float scores before Stage 1's 8-bit +// histogram binning, aiming for a more uniform distribution +// across the 256 coarse bins. Stage 2 refinement still uses +// convert_to_uint32() on raw floats, so correctness is preserved. +// +// Modes 3/4/6/7 use a data-adaptive linear mapping to [0,255] +// instead of fp16 bit-pattern bucketing, guaranteeing full +// bucket utilization regardless of value range. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // Original convert_to_uint8 behavior + MAPPING_LUT_CDF = 1, // LUT-based CDF equalization + MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping + MAPPING_POWER = 3, // Monotonic power transform + MAPPING_LOG = 4, // Log transform + MAPPING_INDEX_CACHE = 5, // Sentinel: reuse previous layer's indices (Python-level skip) + MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp + MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing +}; + +struct TopKMappingParams { + int mode; // TopKMappingMode + float power_exp; // For MAPPING_POWER (default 0.5) + const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr + const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr +}; + +// NOTE: convert_to_uint8() must be defined before including this header. +// It is defined in topk_sglang.cu within the anonymous namespace. + +// ---- Individual transform functions (return float, no bucketing) ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +// ---- Transform dispatcher (returns float, no bucketing) ---- + +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + default: return x; + } +} + +// ---- Linear bucketing for transform modes ---- + +__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { + int bin = __float2int_rd((val - range_min) * inv_range); + return static_cast(min(max(bin, 0), 255)); +} + +// ---- BF16 upper-8-bit bucketing (mode 8) ---- + +__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { + __nv_bfloat16 bf = __float2bfloat16_rn(x); + uint16_t bits = __bfloat16_as_ushort(bf); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +// ---- Non-transform mapping functions (unchanged) ---- + +// LUT-based CDF equalization: lut[original_bin] -> equalized_bin +__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { + return s_lut[convert_to_uint8(x)]; +} + +// Quantile mapping: binary search over 256 sorted thresholds +__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { + // Binary search: find largest index i such that x >= s_quantiles[i] + // s_quantiles is sorted ascending, length 256 + int lo = 0, hi = 255; +#pragma unroll 8 + for (int iter = 0; iter < 8; ++iter) { + int mid = (lo + hi + 1) >> 1; + if (x >= s_quantiles[mid]) { + lo = mid; + } else { + hi = mid - 1; + } + } + return static_cast(lo); +} + +// ---- Unified dispatcher ---- +// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. + +__device__ __forceinline__ uint8_t mapped_convert_to_uint8( + float x, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles, + float range_min, + float inv_range) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + if (params.lut != nullptr) return map_lut_cdf(x, s_lut); + return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + case MAPPING_QUANTILE: + if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); + return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated + case MAPPING_POWER: + case MAPPING_LOG: + case MAPPING_ASINH: + case MAPPING_LOG1P: { + float val = apply_transform(x, params); + return linear_map_to_uint8(val, range_min, inv_range); + } + case MAPPING_TRUNC8: + return convert_to_uint8_bf16(x); + default: // MAPPING_NONE + return convert_to_uint8(x); + } +} + +// Helper: check if a mapping mode needs the auto-range pre-pass +__device__ __forceinline__ bool needs_auto_range(int mode) { + return (mode == MAPPING_POWER || mode == MAPPING_LOG || + mode == MAPPING_ASINH || mode == MAPPING_LOG1P); +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 314f0fde..59592708 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -6,821 +6,1090 @@ * 2. optimize the performance a little * 3. fix the potential illegal memory access */ - #include - #include - #include - #include - #include - #include - #include - #include - #include - - #include - #include - #include - - namespace { - - constexpr int TopK = 2048; - constexpr int kThreadsPerBlock = 1024; - - #ifdef USE_ROCM - // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a - // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. - #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES - constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); - #else - constexpr size_t kSmem = 48 * 1024; // bytes - #endif - #else - // Reduced from 128KB to 32KB to improve occupancy. - // Each radix pass needs at most ~TopK candidates in the threshold bin, - // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. - constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) - #endif - - struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; - }; - - // when length <= TopK, we can directly write the indices - __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? i : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? src_page_table[i] : -1; - } - } - - // keep the first `length` entries, set others to -1 - __device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } - } - - __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); - } - - __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); - } - - __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } - } - - __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } - } - - auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; - } - - template - void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { - #ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - #endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex integration: BOS/EOS-aware segmented TopK with index remapping - // ====================================================================== - - template - __device__ __forceinline__ float vortex_to_float(T x); - - template <> - __device__ __forceinline__ float vortex_to_float(float x) { return x; } - - template <> - __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); - } - - constexpr int VORTEX_MAX_TOPK = 2048; - - // Templated version of fast_topk_cuda_tl: - // - ScoreT: float or __nv_bfloat16 - // - target_k: runtime parameter (replaces compile-time TopK) - template - __device__ void fast_topk_vortex( - const ScoreT* __restrict__ input, - int* __restrict__ index, - int row_start, - int length, - int target_k) - { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; - - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // Stage 1: 8-bit coarse histogram - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&vh_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { - #pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - convert_to_uint8(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes - #pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } - } - - // Wrapper kernel: one CUDA block per batch*head segment - template - __global__ __launch_bounds__(kThreadsPerBlock) - void TopKOutput_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos) - { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } - } - - } // namespace - - #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - - void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); - CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); - } - - // ====================================================================== - // Vortex host entry point — same interface as topk_output in topk.cu - // ====================================================================== - void topk_output_sglang( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages) - { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); - } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); - } \ No newline at end of file +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +// Include mapping strategies (must come after convert_to_uint8 definition) +#include "topk_mapping.cuh" + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Vortex integration: BOS/EOS-aware segmented TopK with index remapping +// ====================================================================== + +template +__device__ __forceinline__ float vortex_to_float(T x); + +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } + +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; + +// Templated version of fast_topk_cuda_tl: +// - ScoreT: float or __nv_bfloat16 +// - target_k: runtime parameter (replaces compile-time TopK) +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing + if (needs_auto_range(mapping.mode)) { + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper kernel: one CUDA block per batch*head segment +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling histogram kernel: runs only Stage 1 and returns per-segment +// 256-bin histograms for distribution analysis +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKHistogram_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + int* __restrict__ histograms, // [eff_batch_size, 256] + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + __shared__ float s_range_min, s_range_inv_range; + + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + + const ScoreT* __restrict__ score_blk = score + start; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max for transform modes + if (needs_auto_range(mapping.mode)) { + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Initialize shared histogram + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + + // Build histogram over the segment with mapping + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(score_blk[idx]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + // Write to global memory + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) { + out[tx] = s_histogram[tx]; + } +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Vortex host entry point — same interface as topk_output in topk.cu +// ====================================================================== +void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + // Build mapping params from optional tensors + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_output: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: collect per-segment 256-bin histograms of Stage 1 bins +// ====================================================================== +void topk_profile_histogram( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& histograms, + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + // Build mapping params + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_histogram: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); +} + diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..d14650fd --- /dev/null +++ b/examples/README.md @@ -0,0 +1,399 @@ +# Vortex Torch Examples + +End-to-end accuracy evaluation and profiling pipelines for Vortex sparse attention on top of the SGLang inference engine. The scripts in this directory evaluate different TopK kernel variants, mapping functions, KV-cache quantization settings, and external sparse-attention backends on math reasoning benchmarks. + +--- + +## Mapping Functions Reference + +The TopK Stage-1 radix histogram uses 256 uint8 bins. A **mapping function** transforms raw attention scores before binning to improve bucket uniformity and reduce tail latency. Set via `--topk-mapping-mode`. + +| Mode | Name | Formula | Requires Calibration | Hyperparameter (`--topk-mapping-power`) | +|------|------|---------|---------------------|-----------------------------------------| +| 0 | None | FP16 bit-pattern bucketing | No | — | +| 1 | LUT CDF | `lut[original_bin]` (CDF equalization) | Yes (`--topk-mapping-lut-path`) | — | +| 2 | Quantile | Binary search over 256 float thresholds | Yes (`--topk-mapping-quantiles-path`) | — | +| 3 | Power | `sign(x) * \|x\|^p` | No | `p` (exponent, default 0.5) | +| 4 | Log | `sign(x) * log(\|x\| + 1)` | No | — | +| 5 | Index Cache | Reuse top-k indices from a preceding layer | No | — (see `--index-cache-shared-layers`) | +| 6 | Asinh | `asinh(beta * x)` | No | `beta` (default 0.5) | +| 7 | Log1p | `sign(x) * log1p(alpha * \|x\|)` | No | `alpha` (default 0.5) | +| 8 | Trunc8 | BF16 upper-8-bit bucketing | No | — | + +Modes 1 and 2 require an offline calibration step (see `calibrate_topk.py` in `benchmarks/`). Modes 3, 6, and 7 accept a tunable hyperparameter via `--topk-mapping-power`. + +--- + +## Python Scripts + +### `verify_algo.py` — End-to-End Accuracy Benchmark + +The primary evaluation script. Loads AMC 2023 math problems from `amc23.jsonl`, runs inference via the SGLang engine with Vortex sparse attention, and scores answers using `lighteval`'s extractive-match metric. Reports `mean@N`, `pass@N`, throughput, and memory access cost. + +**Usage:** + +```bash +python verify_algo.py [OPTIONS] +``` + +**CLI Arguments:** + +| Argument | Default | Description | +|----------|---------|-------------| +| `--trials` | 2 | Number of trials (each prompt repeated N times) | +| `--topk-val` | 30 | Number of top-k pages to select per segment | +| `--page-size` | 16 | Tokens per KV-cache page | +| `--vortex-module-name` | `gqa_block_sparse_attention` | Sparse attention algorithm module | +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model identifier | +| `-f`, `--full-attention` | off | Disable sparse attention (full-attention baseline) | +| `--mem` | 0.8 | Static GPU memory fraction for SGLang | +| `--kv-cache-dtype` | `auto` | KV cache dtype: `auto`, `fp8_e5m2`, `fp8_e4m3`, `int8` | +| `--topk-type` | `naive` | TopK kernel: `naive` (CUB radix sort) or `sglang` (fast two-stage radix) | +| `--topk-mapping-mode` | 0 | Mapping function for Stage-1 binning (see table above) | +| `--topk-mapping-power` | 0.5 | Hyperparameter for modes 3/6/7 | +| `--topk-mapping-lut-path` | None | `.npy` uint8[256] LUT for mode 1 | +| `--topk-mapping-quantiles-path` | None | `.npy` float32[256] quantiles for mode 2 | +| `--index-cache-shared-layers` | None | Layer IDs that skip the indexer and reuse a previous layer's indices | + +**Fixed engine settings:** `attention_backend=flashinfer`, `vortex_max_seq_lens=12288`, layer 0 skipped, `reserved_bos=1`, `reserved_eos=2`. Sampling: `temperature=0.6`, `top_p=0.95`, `top_k=20`, `max_new_tokens=8192`. + +**Index cache note (mode 5):** When `--topk-mapping-mode 5` is set without `--index-cache-shared-layers`, the script defaults to even layers `[2, 4, 6, ..., 26]` and internally resets the mapping mode to 0 while passing the shared-layer list to the engine. + +**Example — full-attention baseline:** + +```bash +python verify_algo.py --full-attention --trials 8 --mem 0.7 +``` + +**Example — sglang TopK with power mapping:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power 0.25 \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +**Example — sglang TopK with calibrated LUT:** + +```bash +python verify_algo.py \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path calibration/lut.npy \ + --trials 8 --topk-val 30 --mem 0.7 +``` + +--- + +### `verify_aim24.py` — AIME 2024 Throughput Test (Legacy) + +A standalone throughput script that loads AIME 2024 from HuggingFace (`HuggingFaceH4/aime_2024`), builds chat prompts using the Qwen3 tokenizer with `enable_thinking=True`, and repeats each prompt 8 times. Outputs a JSONL file with generation results and timing metadata. Does **not** compute accuracy metrics. + +**Usage:** + +```bash +python verify_aim24.py +``` + +All settings are hard-coded (no CLI arguments): + +| Setting | Value | +|---------|-------| +| Model | `Qwen/Qwen3-0.6B` | +| Page size | 16 | +| Selected pages | 29 | +| Max sequence length | 20480 | +| Module | `block_sparse_attention` | +| Memory fraction | 0.9 | +| Max new tokens | 16384 | +| CUDA graph | Enabled | + +--- + +## Shell Scripts + +All shell scripts set `CUDA_VISIBLE_DEVICES` and save timestamped logs to `results/`. + +### `verify_algo.sh` — Baseline TopK Comparison (Naive vs SGLang) + +Runs `verify_algo.py` with `block_sparse_attention` comparing the `naive` and `sglang` TopK kernels. Each configuration is repeated `REPEAT_COUNT` times (default 3, overridable via environment variable). + +```bash +REPEAT_COUNT=5 bash verify_algo.sh +``` + +### `verify_algo_topk.sh` — Naive vs SGLang Comparison + +Similar to `verify_algo.sh` but simpler: runs `naive` TopK and `sglang` TopK back-to-back for `block_sparse_attention`, each with 8 trials. + +### `verify_algo_quant.sh` — INT8 KV-Cache Quantization + +Tests sparse attention with `--kv-cache-dtype int8` to measure accuracy under quantized KV caches. + +```bash +bash verify_algo_quant.sh +``` + +### `verify_sparse_backends.sh` — External Sparse Attention Backends + +Evaluates three external sparse-attention algorithms integrated via the Vortex flow interface: + +- `nsa` (Native Sparse Attention) +- `fsa` (Flash Sparse Attention) +- `flash_moba` (Flash MoBA) + +```bash +bash verify_sparse_backends.sh +``` + +### `verify_algo_topk_mapping.sh` — Full Mapping Mode Sweep + +Comprehensive sweep across all mapping modes: + +1. **Baseline:** `naive` TopK, mode 0 +2. **Calibration:** runs `calibrate_topk.py` to generate `lut.npy` and `quantiles.npy` (skipped if files exist) +3. **Mode 1** (LUT CDF) and **Mode 2** (Quantile) with calibrated tables +4. **Modes 0, 3, 4** (no calibration needed) — Power mode uses `--topk-mapping-power 0.5` +5. **Mode 6** (Asinh) — sweeps `beta` in `[0.5, 1.0, 2.0]` +6. **Mode 7** (Log1p) — sweeps `alpha` in `[0.5, 1.0, 2.0]` + +```bash +export CUDA_VISIBLE_DEVICES=0 +bash verify_algo_topk_mapping.sh +``` + +### `verify_algo_topk_mapping_new.sh` — Parametric Mapping Sweep (Modes 3, 6, 7) + +Focused hyperparameter sweep for the three parametric modes, preceded by an auto-tuning step: + +| Mode | Parameter | Sweep Values | +|------|-----------|-------------| +| 3 (Power) | `p` | 0.1, 0.25, 0.75, 0.9 | +| 6 (Asinh) | `beta` | 0.1, 0.5, 1.0, 2.0, 4.0 | +| 7 (Log1p) | `alpha` | 0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0 | + +Requires `calibration/raw_histograms.npy` for the auto-tune step. + +```bash +export CUDA_VISIBLE_DEVICES=5 +bash verify_algo_topk_mapping_new.sh +``` + +### `verify_algo_topk_mapping_indexcache.sh` — Index Cache (Mode 5) + +Tests the index-cache optimization where even-numbered layers `[2, 4, 6, ..., 26]` reuse top-k indices from the nearest preceding full layer, skipping their indexer entirely. + +```bash +bash verify_algo_topk_mapping_indexcache.sh +``` + +### `run_topk_benchmark.sh` — Unified TopK Benchmark Pipeline + +The most comprehensive benchmarking script. Three-step pipeline: + +1. **Calibrate** — collect real-data histograms + LUT/quantile tables +2. **Kernel bench** — latency + histogram profiling across batch sizes, sequence lengths, and distributions, followed by distribution analysis plots and auto-tuning +3. **E2E accuracy** — full-attention baseline plus every mapping mode + +```bash +bash run_topk_benchmark.sh --gpu 5 --trials 8 --model-name Qwen/Qwen3-1.7B +``` + +| Option | Default | Description | +|--------|---------|-------------| +| `--model-name` | `Qwen/Qwen3-1.7B` | HuggingFace model | +| `--topk-val` | 30 | Top-k pages | +| `--trials` | 8 | E2E trial count | +| `--mem` | 0.7 | GPU memory fraction | +| `--gpu` | 5 | CUDA device | +| `--algo` | `block_sparse_attention` | Sparse attention algorithm | +| `--skip-calibrate` | off | Reuse existing calibration | +| `--skip-kernel` | off | Skip kernel-level latency step | +| `--skip-e2e` | off | Skip E2E accuracy step | + +### `run_distribution_analysis.sh` — Bucket Distribution Profiling (All Modes) + +Three-step pipeline to analyze how each mapping mode affects the 256-bin bucket distribution: + +1. **Calibrate** — collect real-data histograms (skippable with `--real-histograms`) +2. **Bench** — histogram profiling with modes 0–8 on `bucket_uniform` and `normal` distributions +3. **Analyze** — generate comparison plots and CSV bucket count tables + +```bash +bash run_distribution_analysis.sh --gpu 5 +bash run_distribution_analysis.sh --gpu 5 --real-histograms /path/to/raw_histograms.npy +``` + +### `run_distribution_analysis_new.sh` — Bucket Distribution Profiling (Modes 3, 6, 7) + +Same pipeline as above but focused on parametric modes only, with an additional auto-tune step: + +1. **Calibrate** (or skip with existing histograms) +2. **Auto-tune** — sweep hyperparameters on synthetic data +3. **Bench** — histogram profiling for modes 3, 6, 7, 8 +4. **Analyze** — comparison plots + tables + +```bash +bash run_distribution_analysis_new.sh --gpu 5 +``` + +--- + +## Benchmarks Directory Scripts + +The `benchmarks/` directory contains standalone profiling and analysis tools used by the shell pipelines above. These can also be run independently. + +### `calibrate_topk.py` — Offline Calibration + +Runs the SGLang engine on real prompts from `amc23.jsonl` with histogram collection enabled. Produces three files: + +- `lut.npy` — uint8[256] CDF-equalized LUT for mode 1 +- `quantiles.npy` — float32[256] quantile breakpoints for mode 2 +- `raw_histograms.npy` — raw per-sample 256-bin histograms + +```bash +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 --mem 0.7 \ + --output-dir calibration/ +``` + +### `bench_topk.py` — Kernel-Level Latency Benchmark + +Benchmarks `topk_output` (naive/CUB) and `topk_output_sglang` (fast radix) across configurable sweeps of batch size, sequence length, TopK value, KV heads, and score distributions. Optionally collects 256-bin histogram statistics. + +```bash +python benchmarks/bench_topk.py \ + --batch-sizes 4 8 16 \ + --seq-lens 2048 4096 8192 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --distributions normal lognormal uniform bucket_uniform \ + --histogram \ + --repeat 100 \ + --output-json results.json +``` + +### `autotune_topk_mapping.py` — Hyperparameter Auto-Tuning + +Sweeps hyperparameters for parametric mapping modes (3, 6, 7) using the `topk_profile_histogram` kernel on synthetic data. Ranks configurations by resolution rate, Gini coefficient, max/mean ratio, and nonzero bins. + +```bash +python benchmarks/autotune_topk_mapping.py \ + --topk-val 30 --batch-size 4 --seq-len 4096 --num-kv-heads 2 \ + --real-histograms calibration/raw_histograms.npy \ + --output-json autotune_results.json +``` + +### `analyze_topk_distribution.py` — Visualization and Analysis + +Loads profiling data and generates: +- Per-segment 256-bin bar charts +- Heatmaps (segments x bins, log-scale) +- Before/after LUT mapping comparisons +- Mode comparison grouped bar charts (Gini + max/mean) +- Distribution comparison plots across data sources +- CSV bucket count tables + +```bash +python benchmarks/analyze_topk_distribution.py \ + --bench-json bench_distribution.json \ + --real-histograms calibration/raw_histograms.npy \ + --output-dir plots/ +``` + +### `profile_topk_distribution.py` — Offline Table Generation + +Computes LUT and quantile tables from pre-collected histograms or raw scores without running a model. Outputs a single `.npz` archive. + +```bash +python benchmarks/profile_topk_distribution.py \ + --histograms-input raw_histograms.npy \ + --output mapping_tables.npz +``` + +### `greedy_layer_search.py` — Index Cache Layer Selection + +Greedy forward-selection of layers whose indexer can be skipped (index cache). Iteratively adds layers to the shared set as long as accuracy stays above `--threshold` times the baseline. + +```bash +cd examples && python ../benchmarks/greedy_layer_search.py \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 \ + --threshold 0.95 \ + --trials 1 \ + --num-layers 28 \ + --mem 0.7 +``` + +--- + +## Data Files + +| File | Description | +|------|-------------| +| `amc23.jsonl` | AMC 2023 math problems with `prompt` and `answer` fields, used by `verify_algo.py` and `calibrate_topk.py` | + +--- + +## Output Structure + +Results are saved under `results/` in timestamped directories: + +``` +results/ +├── dist_analysis_YYYYMMDD_HHMMSS/ +│ ├── step1_calibrate.log +│ ├── step2_autotune.log / step2_bench.log +│ ├── step3_bench.log / step3_analyze.log +│ ├── step4_analyze.log +│ ├── autotune_results.json +│ ├── bench_distribution.json +│ ├── distribution_comparison_*.png +│ ├── bucket_counts_*.csv +│ └── calibration/ +│ ├── lut.npy +│ ├── quantiles.npy +│ └── raw_histograms.npy +├── topk_benchmark_YYYYMMDD_HHMMSS/ +│ ├── kernel_latency.json +│ ├── e2e/ +│ │ ├── full_attention_baseline.log +│ │ ├── sglang_mode0_none.log +│ │ └── ... +│ └── calibration/ +└── *.log (individual run logs) +``` + +--- + +## Quick Start: Typical Workflow + +```bash +export CUDA_VISIBLE_DEVICES=0 + +# 1. Calibrate to generate LUT + quantile tables +python benchmarks/calibrate_topk.py \ + --model-name Qwen/Qwen3-1.7B --topk-val 30 --mem 0.7 \ + --output-dir examples/calibration/ + +# 2. Run full-attention baseline +python examples/verify_algo.py --full-attention --trials 8 --mem 0.7 + +# 3. Evaluate sparse attention with different mapping modes +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 0 --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 3 --topk-mapping-power 0.25 \ + --trials 8 --mem 0.7 + +python examples/verify_algo.py \ + --topk-type sglang --topk-mapping-mode 6 --topk-mapping-power 1.0 \ + --trials 8 --mem 0.7 + +# 4. Or run the full pipeline in one shot +bash examples/run_topk_benchmark.sh --gpu 0 --trials 8 +``` diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh new file mode 100755 index 00000000..287c4545 --- /dev/null +++ b/examples/run_distribution_analysis.sh @@ -0,0 +1,141 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution Profiling Pipeline +# +# Profiles the SGLang TopK kernel's first-pass bucket distribution +# to identify hotspot buckets causing tail latency. +# +# Three steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Bench — histogram profiling (bucket_uniform + normal) +# 3. Analyze — comparison plots + bucket count tables +# +# All outputs (JSON, plots, CSV tables, logs) are written to a +# single timestamped folder under examples/results/dist_analysis_*. +# +# Usage: +# bash run_distribution_analysis.sh --gpu 5 +# bash run_distribution_analysis.sh --gpu 5 \ +# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# ============================================================ + +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +MEM=0.7 +ALGO="block_sparse_attention" +# The path to the raw_histograms.npy file (set to skip calibration) +# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Bucket Distribution Profiling Pipeline" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# ── Step 2: Histogram profiling (bucket_uniform + normal) ───── +echo "" +echo ">>> Step 2: Kernel-level histogram profiling (bucket_uniform + normal)" + +BENCH_JSON="${RUN_DIR}/bench_distribution.json" + +python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals "${TOPK_VAL}" \ + --num-kv-heads 2 \ + --distributions bucket_uniform normal \ + --histogram \ + --filter-kernels sglang_m0 sglang_m1 sglang_m2 sglang_m3 sglang_m4 sglang_m6 sglang_m7 sglang_m8 \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_bench.log" + +echo ">>> Step 2: Done. Results saved to ${BENCH_JSON}" + +# ── Step 3: Analyze — comparison plots + tables ─────────────── +echo "" +echo ">>> Step 3: Generating distribution comparison plots + tables" + +python "${BENCH_DIR}/analyze_topk_distribution.py" \ + --bench-json "${BENCH_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --output-dir "${RUN_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step3_analyze.log" + +echo ">>> Step 3: Done." + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution Profiling Complete" +echo " All outputs in: ${RUN_DIR}/" +echo " bench_distribution.json — raw benchmark data" +echo " distribution_comparison.png — bucket dist plots" +echo " bucket_counts.csv — per-bucket count table" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh new file mode 100755 index 00000000..3dc1bd41 --- /dev/null +++ b/examples/run_distribution_analysis_new.sh @@ -0,0 +1,150 @@ +#!/usr/bin/env bash +# ============================================================ +# Bucket Distribution Profiling Pipeline (modes 3, 6, 7 only) +# +# Tests only the parametric mapping modes with auto-tuning: +# Mode 3 (Power): y = sign(x) * |x|^p +# Mode 6 (Asinh): y = asinh(beta * x) +# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) +# Mode 8 (Trunc8): bf16 upper-8-bit bucketing +# +# Four steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Auto-tune — sweep hyperparameters on synthetic data +# 3. Bench — histogram profiling (bucket_uniform + normal) +# 4. Analyze — comparison plots + bucket count tables +# +# Usage: +# bash run_distribution_analysis_new.sh --gpu 5 +# bash run_distribution_analysis_new.sh --gpu 5 \ +# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +MEM=0.7 +ALGO="block_sparse_attention" +# The path to the raw_histograms.npy file (set to skip calibration) +# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="${SCRIPT_DIR}/calibration/raw_histograms.npy" +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Bucket Distribution Profiling (modes 3, 6, 7)" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# ── Step 2: Auto-tune — sweep hyperparameters on synthetic data ───── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7)" + +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len 4096 \ + --num-kv-heads 2 \ + --real-histograms "${REAL_HIST_PATH}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + +echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + +# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── +echo "" +echo ">>> Step 3: Kernel-level histogram profiling (modes 3, 6, 7)" + +BENCH_JSON="${RUN_DIR}/bench_distribution.json" + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals "${TOPK_VAL}" \ + --num-kv-heads 2 \ + --distributions bucket_uniform normal \ + --histogram \ + --real-histograms "${REAL_HIST_PATH}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --filter-kernels sglang_m3 sglang_m6 sglang_m7 sglang_m8 \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" + +echo ">>> Step 3: Done. Results saved to ${BENCH_JSON}" + +# ── Step 4: Analyze — comparison plots + tables ─────────────── +echo "" +echo ">>> Step 4: Generating distribution comparison plots + tables" + +python "${BENCH_DIR}/analyze_topk_distribution.py" \ + --bench-json "${BENCH_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --output-dir "${RUN_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step4_analyze.log" + +echo ">>> Step 4: Done." + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Bucket Distribution Profiling Complete (modes 3, 6, 7)" +echo " All outputs in: ${RUN_DIR}/" +echo " autotune_results.json — hyperparameter sweep rankings" +echo " bench_distribution.json — raw benchmark data" +echo " distribution_comparison.png — bucket dist plots" +echo " bucket_counts.csv — per-bucket count table" +echo " step{1,2,3,4}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh new file mode 100755 index 00000000..5a7ed94e --- /dev/null +++ b/examples/run_topk_benchmark.sh @@ -0,0 +1,294 @@ +#!/usr/bin/env bash +# ============================================================ +# TopK Benchmark +# +# Compares ALL TopK kernel variants under controlled conditions: +# Step 1: Calibrate (for modes 1/2) +# Step 2: Kernel-level latency (bench_topk.py, all 6 modes) +# Step 3: E2E accuracy (verify_algo.py) +# - Full-attention baseline first +# - Then naive, sglang mode 0/1/2/3/4 +# - Same model, same prompts, deterministic sampling +# +# Fairness improvements over verify_algo_topk_mapping.sh: +# - Full-attention baseline for absolute reference +# - All modes in one sweep (including calibrated 1/2) +# - Sequential runs on same CUDA device minimize interference +# - Deterministic sampling (temperature=0) for reproducibility +# - Results saved to a single timestamped directory +# +# Usage: +# bash run_topk_benchmark.sh [OPTIONS] +# +# Options: +# --model-name NAME HuggingFace model (default: Qwen/Qwen3-1.7B) +# --topk-val K Top-k value (default: 30) +# --trials N E2E trial count (default: 8) +# --mem FRAC GPU memory fraction (default: 0.7) +# --gpu GPU_ID CUDA device (default: 0) +# --algo NAME Sparse attention algorithm (default: block_sparse_attention) +# --skip-calibrate Reuse existing calibration data +# --skip-kernel Skip kernel-level benchmark (step 2) +# --skip-e2e Skip E2E accuracy benchmark (step 3) +# ============================================================ +set -euo pipefail + +# use GPU_ID to set the GPU id you want to use +GPU_ID=5 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +TRIALS=8 +MEM=0.7 +ALGO="block_sparse_attention" +SKIP_CALIBRATE=false +SKIP_KERNEL=false +SKIP_E2E=true + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --trials) TRIALS="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --skip-calibrate) SKIP_CALIBRATE=true; shift ;; + --skip-kernel) SKIP_KERNEL=true; shift ;; + --skip-e2e) SKIP_E2E=true; shift ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Fair Unified TopK Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Trials: ${TRIALS}" +echo " GPU: ${GPU_ID}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate (for modes 1/2) ──────────────────────── +CALIBRATION_DIR="${RUN_DIR}/calibration" +if [ "${SKIP_CALIBRATE}" = true ] && [ -d "${CALIBRATION_DIR}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (--skip-calibrate)" +else + echo "" + echo ">>> Step 1: Calibrating — collecting histograms for LUT/quantile modes" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + echo ">>> Step 1: Done." +fi + +# ── Step 2: Kernel-level latency benchmark ──────────────────── +if [ "${SKIP_KERNEL}" = true ]; then + echo "" + echo ">>> Step 2: SKIPPED (--skip-kernel)" +else + # Step 2a: Auto-tune parametric mapping modes (must run before bench) + echo "" + echo ">>> Step 2a: Auto-tuning parametric mapping hyperparameters" + AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + REAL_HIST_ARGS="" + if [ -f "${CALIBRATION_DIR}/raw_histograms.npy" ]; then + REAL_HIST_ARGS="--real-histograms ${CALIBRATION_DIR}/raw_histograms.npy" + fi + python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len 4096 \ + --num-kv-heads 2 \ + ${REAL_HIST_ARGS} \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2a_autotune.log" + echo ">>> Step 2a: Done. Autotune results saved to ${AUTOTUNE_JSON}" + + # Step 2b: Kernel-level latency + histogram benchmark (using autotune params) + echo "" + echo ">>> Step 2b: Kernel-level latency benchmark (all modes)" + + BENCH_JSON="${RUN_DIR}/kernel_latency.json" + + # Build calibration args + LUT_ARGS="" + if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then + LUT_ARGS="--lut-path ${CALIBRATION_DIR}/lut.npy" + fi + QUANTILES_ARGS="" + if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then + QUANTILES_ARGS="--quantiles-path ${CALIBRATION_DIR}/quantiles.npy" + fi + + python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 8 16 32 \ + --seq-lens 2048 4096 8192 16384 \ + --topk-vals "${TOPK_VAL}" \ + --num-kv-heads 2 4 \ + --distributions normal lognormal uniform \ + --histogram \ + --hit-rate \ + --warmup 20 \ + --repeat 100 \ + ${LUT_ARGS} \ + ${QUANTILES_ARGS} \ + --autotune-json "${AUTOTUNE_JSON}" \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2b_kernel_bench.log" + + echo ">>> Step 2b: Done. Results saved to ${BENCH_JSON}" + + # Step 2c: Per-mode distribution analysis + echo "" + echo ">>> Step 2c: Generating per-mode distribution analysis" + + python "${BENCH_DIR}/analyze_topk_distribution.py" \ + --bench-json "${BENCH_JSON}" \ + ${REAL_HIST_ARGS} \ + --output-dir "${RUN_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step2c_analyze.log" + + echo ">>> Step 2c: Done. Per-mode plots saved to ${RUN_DIR}" +fi + +# ── Step 3: E2E accuracy comparison ────────────────────────── +if [ "${SKIP_E2E}" = true ]; then + echo "" + echo ">>> Step 3: SKIPPED (--skip-e2e)" +else + echo "" + echo ">>> Step 3: E2E accuracy comparison" + + E2E_DIR="${RUN_DIR}/e2e" + mkdir -p "${E2E_DIR}" + + # Helper: run verify_algo.py with common args and save output + run_e2e() { + local label="$1" + shift + local logfile="${E2E_DIR}/${label}.log" + echo "" + echo " --- ${label} ---" + { time python "${SCRIPT_DIR}/verify_algo.py" \ + --trials "${TRIALS}" \ + --topk-val "${TOPK_VAL}" \ + --model-name "${MODEL_NAME}" \ + --mem "${MEM}" \ + "$@" ; } \ + 2>&1 | tee "${logfile}" + } + + # 3a. Full-attention baseline (oracle) + run_e2e "full_attention_baseline" \ + --full-attention + + # 3b. Naive TopK + run_e2e "naive_mode0" \ + --vortex-module-name "${ALGO}" \ + --topk-type naive + + # 3c. SGLang mode 0 (no mapping) + run_e2e "sglang_mode0_none" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 0 + + # 3d. SGLang mode 1 (LUT CDF) — requires calibration + if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then + run_e2e "sglang_mode1_lut_cdf" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" + else + echo " --- sglang_mode1_lut_cdf: SKIPPED (no lut.npy) ---" + fi + + # 3e. SGLang mode 2 (quantile) — requires calibration + if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then + run_e2e "sglang_mode2_quantile" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 2 \ + --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" + else + echo " --- sglang_mode2_quantile: SKIPPED (no quantiles.npy) ---" + fi + + # 3f. SGLang mode 3 (power) + run_e2e "sglang_mode3_power" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power 0.5 + + # 3g. SGLang mode 4 (log) + run_e2e "sglang_mode4_log" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 4 + + # 3h. SGLang mode 6 (asinh) + run_e2e "sglang_mode6_asinh" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power 1.0 + + # 3i. SGLang mode 7 (log1p) + run_e2e "sglang_mode7_log1p" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power 1.0 + + echo "" + echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" + + # ── Summary table: extract pass@N from each log ───────────── + echo "" + echo "============================================================" + echo "E2E Accuracy Summary" + echo "============================================================" + printf "%-35s %s\n" "Configuration" "Result" + printf "%-35s %s\n" "-----------------------------------" "------" + for logfile in "${E2E_DIR}"/*.log; do + label=$(basename "${logfile}" .log) + # Extract the last line matching pass@ pattern + result=$(grep -oP 'pass@\d+\s*[=:]\s*[\d.]+' "${logfile}" | tail -1 || echo "N/A") + printf "%-35s %s\n" "${label}" "${result}" + done + echo "============================================================" +fi + +# ── Final Summary ───────────────────────────────────────────── +echo "" +echo "============================================================" +echo "TopK Benchmark Complete" +echo " All results: ${RUN_DIR}" +echo " Calibration: ${CALIBRATION_DIR}" +[ "${SKIP_KERNEL}" != true ] && echo " Kernel JSON: ${RUN_DIR}/kernel_latency.json" +[ "${SKIP_KERNEL}" != true ] && echo " Per-mode: ${RUN_DIR}/distribution_comparison_m*.png, bucket_counts_m*.csv" +[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" +echo "============================================================" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 91f92e76..e04f787c 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -11,7 +11,11 @@ from lighteval.models.model_output import ModelResponse from datasets import load_dataset, Dataset, concatenate_datasets import argparse +import ast import json +import os +import subprocess +import sys MATH_QUERY_TEMPLATE = """ Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. @@ -57,10 +61,16 @@ def verify_algos( mem: float = 0.8, kv_cache_dtype: str = "auto", topk_type: str = "naive", -): +topk_mapping_mode: int = 0, +topk_mapping_power: float = 0.5, +topk_mapping_lut_path: str = None, +topk_mapping_quantiles_path: str = None, +index_cache_shared_layers: list = None, +disable_cuda_graph: bool = False, +): llm = sgl.Engine(model_path=model_name, - disable_cuda_graph=False, + disable_cuda_graph=disable_cuda_graph, page_size=page_size, vortex_topk_val=topk_val, disable_overlap_schedule=True, @@ -74,16 +84,20 @@ def verify_algos( mem_fraction_static=mem, kv_cache_dtype=kv_cache_dtype, vortex_topk_type=topk_type, + vortex_topk_mapping_mode=topk_mapping_mode, + vortex_topk_mapping_power=topk_mapping_power, + vortex_topk_mapping_lut_path=topk_mapping_lut_path, + vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, + vortex_index_cache_shared_layers=index_cache_shared_layers, ) - with open("amc23.jsonl", "r", encoding="utf-8") as f: requests = [json.loads(line) for line in f] - + requests = requests * trials prompts = [req["prompt"] for req in requests] sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 8192} - + o = llm.generate(prompts, sampling_params) gold_metric = MultilingualExtractiveMatchMetric( language=Language.ENGLISH, @@ -93,7 +107,7 @@ def verify_algos( pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), aggregation_function=max, ) - + results = [] for data, item in zip(requests, o): golds = [data["answer"]] @@ -103,7 +117,7 @@ def verify_algos( result = gold_metric.compute(model_response=ModelResponse(text=[predictions]), doc=target) except: result = 0.0 - + results.append( { "score": float(result), @@ -122,7 +136,7 @@ def verify_algos( # print(f" question: {data['question'][:120]}...") # print(f" prediction: {predictions[:200]}...") # print() - + total_accuracy = 0.0 total_tokens = 0 @@ -236,11 +250,54 @@ def parse_args(): choices=["naive", "sglang"], help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', ) + parser.add_argument( + "--topk-mapping-mode", + type=int, + default=0, + choices=[0, 1, 2, 3, 4, 5, 6, 7], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p (default: 0).', + ) + + parser.add_argument( + "--topk-mapping-power", + type=float, + default=0.5, + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 7 asinh), alpha (mode 8 log1p). Default: 0.5.', + ) + + parser.add_argument( + "--topk-mapping-lut-path", + type=str, + default=None, + help="Path to .npy file with uint8[256] LUT for topk mapping mode 1.", + ) + + parser.add_argument( + "--topk-mapping-quantiles-path", + type=str, + default=None, + help="Path to .npy file with float32[256] quantiles for topk mapping mode 2.", + ) + + parser.add_argument( + "--index-cache-shared-layers", + type=int, + nargs="+", + default=None, + help="Layer IDs that reuse indices from the nearest preceding full layer (skip indexer).", + ) + return parser.parse_args() if __name__ == "__main__": args = parse_args() + # --- Mode 5: Index Cache (default even-layer pattern) --- + if args.topk_mapping_mode == 5: + if args.index_cache_shared_layers is None: + args.index_cache_shared_layers = list(range(2, 28, 2)) # [2,4,6,...,26] + args.topk_mapping_mode = 0 + summary = verify_algos( trials=args.trials, topk_val=args.topk_val, @@ -251,6 +308,11 @@ def parse_args(): mem=args.mem, kv_cache_dtype=args.kv_cache_dtype, topk_type=args.topk_type, + topk_mapping_mode=args.topk_mapping_mode, + topk_mapping_power=args.topk_mapping_power, + topk_mapping_lut_path=args.topk_mapping_lut_path, + topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, + index_cache_shared_layers=args.index_cache_shared_layers, ) print(summary) diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index aa01fe66..3edf9b62 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=6 +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( "block_sparse_attention" @@ -22,4 +23,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done \ No newline at end of file + done diff --git a/examples/verify_algo_quant.sh b/examples/verify_algo_quant.sh index a7601de9..a2663e97 100644 --- a/examples/verify_algo_quant.sh +++ b/examples/verify_algo_quant.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -e -export CUDA_VISIBLE_DEVICES=6 +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 sparse_algos=( "block_sparse_attention" @@ -23,19 +24,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done - - for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/${algo}_fp8_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --kv-cache-dtype fp8_e4m3" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --kv-cache-dtype fp8_e4m3 \ - --topk-type naive \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh new file mode 100644 index 00000000..918252cd --- /dev/null +++ b/examples/verify_algo_topk_mapping.sh @@ -0,0 +1,175 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing +export CUDA_VISIBLE_DEVICES=0 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +sparse_algos=( + "block_sparse_attention" +) + +topk_mapping_modes=( + 0 # none + 3 # power + 4 # log +) +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +# Set this to an existing calibration directory to skip re-running calibration. +# It must contain lut.npy and quantiles.npy (output of calibrate_topk.py). +CALIBRATION_DIR="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration" + +# ============================================================ +# Baseline: naive topk (mode 0) +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_naive_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type naive --topk-mapping-mode 0" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --topk-mapping-mode 0 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Calibration: collect histograms for LUT/quantile generation +# Skipped if CALIBRATION_DIR already has lut.npy + quantiles.npy +# ============================================================ +if [ -f "${CALIBRATION_DIR}/lut.npy" ] && [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then + echo ">>> Calibration SKIPPED (using existing ${CALIBRATION_DIR})" +else + CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + for algo in "${sparse_algos[@]}"; do + echo ">>> Calibrating for ${algo}..." + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-val 30 \ + --mem 0.7 \ + --vortex-module-name "${algo}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" + done +fi + + + +# ============================================================ +# Mode 1: LUT CDF with calibrated LUT +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_1_calibrated_${TIMESTAMP}.log" + echo ">>> Running mode 1 (LUT CDF) with calibrated LUT for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 1 \ + --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 2: Quantile with calibrated quantiles +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_2_calibrated_${TIMESTAMP}.log" + echo ">>> Running mode 2 (quantile) with calibrated quantiles for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 2 \ + --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# sglang topk: modes that don't need calibration (0, 3, 4) +# ============================================================ +for algo in "${sparse_algos[@]}"; do + for topk_mapping_mode in "${topk_mapping_modes[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_${topk_mapping_mode}_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode ${topk_mapping_mode}" + echo ">>> Saving results to ${OUTFILE}" + + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode ${topk_mapping_mode} \ + --topk-mapping-power 0.5 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Mode 6: asinh — sweep beta values +# ============================================================ +for algo in "${sparse_algos[@]}"; do + for beta in 0.5 1.0 2.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" + echo ">>> Running mode 6 (asinh) beta=${beta} for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${beta} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Mode 7: log1p — sweep alpha values +# ============================================================ +for algo in "${sparse_algos[@]}"; do + for alpha in 0.5 1.0 2.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" + echo ">>> Running mode 7 (log1p) alpha=${alpha} for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${alpha} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_indexcache.sh b/examples/verify_algo_topk_mapping_indexcache.sh new file mode 100644 index 00000000..9002084c --- /dev/null +++ b/examples/verify_algo_topk_mapping_indexcache.sh @@ -0,0 +1,45 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +export CUDA_VISIBLE_DEVICES=5 + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +sparse_algos=( + "block_sparse_attention" +) + +# --- Mode 5: Index Cache (default even-layer pattern) --- +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode5_index_cache_${TIMESTAMP}.log" + echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 5 (index cache)" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 5 \ + --index-cache-shared-layers 2 4 6 8 10 12 14 16 18 20 22 24 26 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# --- Mode 6: Greedy layer selection --- +# for algo in "${sparse_algos[@]}"; do +# OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode6_greedy_${TIMESTAMP}.log" +# echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 6 (greedy)" +# echo ">>> Saving results to ${OUTFILE}" +# { time python verify_algo.py \ +# --trials 8 \ +# --topk-val 30 \ +# --vortex-module-name "${algo}" \ +# --model-name Qwen/Qwen3-1.7B \ +# --topk-type sglang \ +# --topk-mapping-mode 6 \ +# --mem 0.7 ; } \ +# 2>&1 | tee "${OUTFILE}" +#done diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh new file mode 100644 index 00000000..b701be28 --- /dev/null +++ b/examples/verify_algo_topk_mapping_new.sh @@ -0,0 +1,128 @@ +#!/usr/bin/env bash +set -e +# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use +# Mapping functions: +# 0: None — original fp16 bit-pattern bucketing +# 1: LUT CDF — LUT-based CDF equalization (calibrated) +# 2: Quantile — piecewise-linear quantile mapping (calibrated) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 5: Index Cache — reuse previous layer's indices +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 8: Trunc8 — bf16 upper-8-bit bucketing +export CUDA_VISIBLE_DEVICES=5 + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +sparse_algos=( + "block_sparse_attention" +) + +# Path to real-data histograms from calibration (for auto-tuning) +REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" + +RESULTS_DIR="results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) + +# ============================================================ +# Step 0: Auto-tune — find best hyperparameters per mode +# Uses topk_profile_histogram kernel on synthetic data (fast, no model) +# ============================================================ +echo "============================================================" +echo "Step 0: Auto-tuning hyperparameters (synthetic data)" +echo "============================================================" +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val 30 \ + --batch-size 4 \ + --seq-len 4096 \ + --num-kv-heads 2 \ + --real-histograms "${REAL_HISTOGRAMS}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" +echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" +echo "" + +# ============================================================ +# Step 1: Mode 3 (power) — sweep p values +# ============================================================ +echo "============================================================" +echo "Step 1: Mode 3 (power) — sweeping p" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + for p in 0.1 0.25 0.75 0.9; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${p}_${TIMESTAMP}.log" + echo ">>> Mode 3 (power) p=${p} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power ${p} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Step 2: Mode 6 (asinh) — sweep beta values +# ============================================================ +echo "============================================================" +echo "Step 2: Mode 6 (asinh) — sweeping beta" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + for beta in 0.1 0.5 1.0 2.0 4.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" + echo ">>> Mode 6 (asinh) beta=${beta} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${beta} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Step 3: Mode 7 (log1p) — sweep alpha values +# ============================================================ +echo "============================================================" +echo "Step 3: Mode 7 (log1p) — sweeping alpha" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + for alpha in 0.1 0.5 0.75 1.0 2.0 4.0 8.0; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" + echo ">>> Mode 7 (log1p) alpha=${alpha} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${alpha} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" + done +done + +# ============================================================ +# Summary +# ============================================================ +echo "" +echo "============================================================" +echo "All sweeps complete. Results in ${RESULTS_DIR}/" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo " Mode 3 (power): p = [0.1, 0.25, 0.75, 0.9]" +echo " Mode 6 (asinh): beta = [0.1, 0.5, 1.0, 2.0, 4.0]" +echo " Mode 7 (log1p): alpha = [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]" +echo "============================================================" diff --git a/third_party/sglang b/third_party/sglang index 20e4c29d..5f51c8ef 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 20e4c29d206046d6b4eb3b57cc26fd20bf9c519b +Subproject commit 5f51c8ef485fb45990c8166f439da2ee695c03c1 diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index d6da9c1a..78e2923a 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -1,5 +1,6 @@ from __future__ import annotations from typing import Any, Final, Union +import numpy as np import torch from ..abs import ContextBase from ..utils import UNSET, Mode @@ -23,6 +24,8 @@ class Context(ContextBase): "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", + "topk_mapping_mode", "topk_mapping_power", "topk_mapping_lut", "topk_mapping_quantiles", + "topk_histogram_enabled", # auxilary memory in graph "_aux_total_bytes", @@ -69,6 +72,11 @@ class Context(ContextBase): page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). topk_type: str #: TopK kernel type: "naive" or "sglang". + topk_mapping_mode: int #: TopK mapping mode (0=none, 1=lut, 2=quantile, 3=power, 4=log). + topk_mapping_power: float #: Power exponent for mapping mode 3. + topk_mapping_lut: object #: Optional uint8[256] LUT tensor for mapping mode 1. + topk_mapping_quantiles: object #: Optional float32[256] quantiles tensor for mapping mode 2. + topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -146,12 +154,26 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.page_reserved_bos = sa.vortex_page_reserved_bos self.page_reserved_eos = sa.vortex_page_reserved_eos self.topk_type = getattr(sa, "vortex_topk_type", "naive") + self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) + self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) + self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) + + device = getattr(model_runner, "device", "cpu") + + # Load calibration data from .npy files when paths are provided + lut_path = getattr(sa, 'vortex_topk_mapping_lut_path', None) + if lut_path is not None: + lut_np = np.load(lut_path).astype(np.uint8) + self.topk_mapping_lut = torch.from_numpy(lut_np).to(device) + + quantiles_path = getattr(sa, 'vortex_topk_mapping_quantiles_path', None) + if quantiles_path is not None: + q_np = np.load(quantiles_path).astype(np.float32) + self.topk_mapping_quantiles = torch.from_numpy(q_np).to(device) self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads ) - - device = getattr(model_runner, "device", "cpu") self.winfo_q_indices = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_offsets = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) self.winfo_kv_lens = torch.zeros((self.max_num_workloads,), dtype=torch.int32, device=device) diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 8859d61a..e4208dc8 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,9 +1,21 @@ import torch -from typing import Dict, Callable, Optional +from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output, topk_output_sglang +from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT +from ..utils import UNSET + +# --- Module-level histogram accumulator for offline calibration --- +_calibration_histograms: List[torch.Tensor] = [] + +def get_calibration_histograms() -> List[torch.Tensor]: + """Return collected histogram tensors (each [eff_bs, 256] int32 on CPU).""" + return _calibration_histograms + +def clear_calibration_histograms() -> None: + """Clear all collected calibration histograms.""" + _calibration_histograms.clear() class topK(vOp): r""" @@ -86,6 +98,7 @@ def __init__(self): super().__init__() self.impl: Optional[Callable] = None self.topk_type: str = "naive" + self.last_histograms: Optional[torch.Tensor] = None # ---------------- profile ---------------- def profile(self, x: vTensor, o: vTensor, ctx: Context) -> None: @@ -232,6 +245,15 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso if self.topk_type == "sglang": # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) + mapping_power = getattr(ctx, 'topk_mapping_power', 0.5) + mapping_lut = getattr(ctx, 'topk_mapping_lut', None) + mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) + # UNSET sentinel is not a valid torch.Tensor — coerce to None + if mapping_lut is UNSET: + mapping_lut = None + if mapping_quantiles is UNSET: + mapping_quantiles = None self.impl( x, ctx.dense_kv_indptr, @@ -243,14 +265,18 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, ctx.max_num_pages_per_request, + mapping_mode, + mapping_power, + mapping_lut, + mapping_quantiles, ) else: - # topk_output (naive): (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) + # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) self.impl( x, ctx.dense_kv_indptr, - ctx.sparse_kv_indptr, ctx.dense_kv_indices, + ctx.sparse_kv_indptr, o, ctx.batch_size * ctx.num_kv_heads, ctx.topk_val, @@ -258,4 +284,30 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_eos, ctx.max_num_pages_per_request, ) + + # Optional histogram profiling (default disabled, no overhead when off). + # Skip entirely during CUDA graph capture — allocations and D2H copies + # are not permitted while a stream is being captured. + if ( + getattr(ctx, 'topk_histogram_enabled', False) + and self.topk_type == "sglang" + and not torch.cuda.is_current_stream_capturing() + ): + eff_bs = ctx.batch_size * ctx.num_kv_heads + self.last_histograms = torch.zeros(eff_bs, 256, dtype=torch.int32, device=x.device) + topk_profile_histogram( + x, + ctx.dense_kv_indptr, + self.last_histograms, + eff_bs, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + mapping_mode, + mapping_power, + mapping_lut, + mapping_quantiles, + ) + # Accumulate histograms for offline calibration + _calibration_histograms.append(self.last_histograms.cpu().clone()) + return o From 31ba23ba830fb64bde97f9145654a5dfb28cd1a4 Mon Sep 17 00:00:00 2001 From: UED Date: Wed, 1 Apr 2026 08:12:40 +0000 Subject: [PATCH 14/24] Enhance TopK mapping modes with new remap functions - Removed outdated GPU architecture flags from setup.py. - Added new mapping modes (Erf, Tanh, Subtract) to analyze_topk_distribution.py and bench_topk.py. - Updated functions to handle new modes and added support for noscale parameters in autotune and benchmark scripts. - Enhanced the TopK kernel with additional profiling metrics and improved handling of kernel arguments. - Updated example scripts to reflect new modes and parameters for distribution analysis. --- CLAUDE.md | 172 +++++++ benchmarks/analyze_topk_distribution.py | 25 +- benchmarks/autotune_topk_mapping.py | 336 +++++++++++-- benchmarks/bench_topk.py | 338 ++++++++++++-- csrc/clean.py | 21 + csrc/register.cc | 29 +- csrc/register.h | 43 +- csrc/topk_mapping.cuh | 44 +- csrc/topk_sglang.cu | 416 ++++++++++++++++- csrc/topk_slgang_ori.cu | 546 ++++++++++++++++++++++ examples/run_distribution_analysis.sh | 105 ++++- examples/run_distribution_analysis_new.sh | 23 +- examples/run_topk_benchmark.sh | 32 +- examples/verify_algo.py | 4 +- examples/verify_algo_topk_mapping.sh | 173 +++++-- examples/verify_algo_topk_mapping_new.sh | 209 ++++++--- setup.py | 2 - third_party/sglang | 2 +- todo.txt | 308 ++++++++++++ vortex_torch/indexer/context.py | 3 + vortex_torch/indexer/output_func.py | 3 + 21 files changed, 2601 insertions(+), 233 deletions(-) create mode 100644 CLAUDE.md create mode 100644 csrc/clean.py create mode 100644 csrc/topk_slgang_ori.cu create mode 100644 todo.txt diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..585a246f --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,172 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. + +## Build & Install + +```bash +# Clone with submodules +git clone -b v1 --recursive + +# Install SGLang dependency (custom fork in third_party/, supports v0.4.9) +cd third_party/sglang && bash install.sh && cd ../../ + +# Install Vortex (editable mode, compiles CUDA extensions for SM_86/SM_89/SM_90) +pip install -e . +``` + +Requires Python >=3.10, torch>=2.7, lighteval[math]==0.12.2. CUDA extensions (`vortex_torch_C`) are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu, topk_sglang.cu). + +## Testing & Verification + +There is no formal test suite (no pytest). Verification is done by running algorithms against SGLang reference output and comparing accuracy on math benchmarks. + +```bash +# Single algorithm verification (from examples/ directory) +python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention + +# Full options +python examples/verify_algo.py \ + --trials 8 --topk-val 30 \ + --vortex-module-name block_sparse_attention \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type naive \ + --mem 0.7 + +# Batch test (outputs timestamped logs to examples/results/) +bash examples/verify_algo.sh + +# AIM24 benchmark verification +python examples/verify_aim24.py +``` + +Available `--topk-type` values: `naive` (CUB-based), `sglang` (SGLang-integrated kernel). + +## AI-Powered Algorithm Generation + +```bash +# Generate new sparse attention algorithms via OpenHands (requires LLM_API_KEY env var) +python openhands_gen.py +``` + +Note: Some auto-generated operators may not be fully optimized. Tune `mem_fraction_static` if OOM occurs. + +## Building Documentation + +```bash +make -C docs html +``` + +Uses Sphinx with myst_parser and furo theme. Deployed via GitHub Actions on push to v1 branch. + +## Architecture + +### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) + +All sparse attention algorithms inherit from `vFlow` and implement three methods: + +- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. +- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. +- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. + +Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. + +### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) + +Operators (`vOp` subclasses) run in two modes: +- **Profile mode**: Pre-compute output shapes and allocate buffers +- **Execute mode**: Perform actual GPU computation + +Operators are split into two parallel hierarchies: +- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load +- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup + +Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. + +### Tensor Format (`vortex_torch/abs/tensor.py`) + +`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. + +### Context System (`vortex_torch/abs/context_base.py`) + +`ContextBase` carries per-step runtime state. Specialized as: +- `Indexer.Context`: Page layout, head config, hardware info +- `Cache.Context`: Page size, total pages, model info + +### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) + +- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) +- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation +- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds + +### Algorithm Registry (`vortex_torch/flow/registry.py`) + +Algorithms are registered via `@register("name")` and looked up with `get(name)`, `has(name)`, `list_keys()`. Factory: `build_vflow(name)` in `loader.py`. + +### SGLang Integration + +Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, transpose operations (NH↔HN), and top-K output routing. + +## Key Conventions + +- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` +- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) +- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` +- **Branch**: Main development is on `v1` + +## Workflow Orchestration + +### 1. Plan Node Default +- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) +- If something goes sideways, STOP and re-plan immediately - don't keep pushing +- Use plan mode for verification steps, not just building +- Write detailed specs upfront to reduce ambiguity + +### 2. Subagent Strategy +- Use subagents liberally to keep main context window clean +- Offload research, exploration, and parallel analysis to subagents +- For complex problems, throw more compute at it via subagents +- One tack per subagent for focused execution + +### 3. Self-Improvement Loop +- After ANY correction from the user: update `tasks/lessons.md` with the pattern +- Write rules for yourself that prevent the same mistake +- Ruthlessly iterate on these lessons until mistake rate drops +- Review lessons at session start for relevant project + +### 4. Verification Before Done +- Never mark a task complete without proving it works +- Diff behavior between main and your changes when relevant +- Ask yourself: "Would a staff engineer approve this?" +- Run tests, check logs, demonstrate correctness + +### 5. Demand Elegance (Balanced) +- For non-trivial changes: pause and ask "is there a more elegant way?" +- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" +- Skip this for simple, obvious fixes - don't over-engineer +- Challenge your own work before presenting it + +### 6. Autonomous Bug Fixing +- When given a bug report: just fix it. Don't ask for hand-holding +- Point at logs, errors, failing tests - then resolve them +- Zero context switching required from the user +- Go fix failing CI tests without being told how + +## Task Management + +1. **Plan First**: Write plan to `tasks/todo.md` with checkable items +2. **Verify Plan**: Check in before starting implementation +3. **Track Progress**: Mark items complete as you go +4. **Explain Changes**: High-level summary at each step +5. **Document Results**: Add review section to `tasks/todo.md` +6. **Capture Lessons**: Update `tasks/lessons.md` after corrections + +## Core Principles + +- **Simplicity First**: Make every change as simple as possible. Impact minimal code. +- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. +- **Minimal Impact**: Changes should only touch what's necessary. Avoid introducing bugs. \ No newline at end of file diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py index 7d944667..00cdf287 100644 --- a/benchmarks/analyze_topk_distribution.py +++ b/benchmarks/analyze_topk_distribution.py @@ -40,6 +40,9 @@ 6: "Asinh", 7: "Log1p", 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", } MAPPING_MODE_FORMULAS = { @@ -52,25 +55,33 @@ 6: "Asinh: asinh(beta*x)", 7: "Log1p: sign(x)*log1p(alpha*|x|)", 8: "Trunc8: bf16 upper-8-bit bucketing", + 9: "Erf: erf(alpha*x)", + 10: "Tanh: tanh(alpha*x)", + 11: "Subtract: x - pivot (RadiK-style)", } def _mode_key_to_display(mode_key: str) -> str: - """Convert a mode key like 'mode_3' or 'mode_3_Power' to a display name.""" + """Convert a mode key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale' to display name.""" + # Handle noscale suffix + noscale = mode_key.endswith("_noscale") + base_key = mode_key[:-len("_noscale")] if noscale else mode_key + suffix = " noscale" if noscale else "" + # Handle new format: "mode_3_Power" - parts = mode_key.split("_", 2) + parts = base_key.split("_", 2) if len(parts) >= 3: - return parts[2] # e.g. "Power" + return parts[2] + suffix # e.g. "Power noscale" # Handle old format: "mode_3" try: mode_num = int(parts[1]) - return MAPPING_MODE_NAMES.get(mode_num, mode_key) + return MAPPING_MODE_NAMES.get(mode_num, base_key) + suffix except (IndexError, ValueError): return mode_key def _mode_key_to_number(mode_key: str) -> int: - """Extract the mode number from a key like 'mode_3' or 'mode_3_Power'.""" + """Extract the mode number from a key like 'mode_3', 'mode_3_Power', or 'mode_3_Power_noscale'.""" parts = mode_key.split("_") try: return int(parts[1]) @@ -314,7 +325,7 @@ def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str): x = np.arange(len(modes)) width = 0.3 - fig, ax1 = plt.subplots(figsize=(10, 5)) + fig, ax1 = plt.subplots(figsize=(max(10, len(modes) * 0.8), 5)) ax2 = ax1.twinx() bars1 = ax1.bar(x - width / 2, ginis, width, label="Gini", color="darkorange") @@ -324,7 +335,7 @@ def plot_mapping_mode_comparison(mode_stats_data: dict, output_dir: str): ax1.set_ylabel("Gini") ax2.set_ylabel("Max/Mean Ratio") ax1.set_xticks(x) - ax1.set_xticklabels(mode_labels, rotation=15, ha="right") + ax1.set_xticklabels(mode_labels, rotation=30, ha="right") ax1.set_ylim(0, 1.1) ax1.set_title("Mapping Mode Comparison") diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index 9b37e32f..d95c8399 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -27,8 +27,8 @@ import numpy as np import torch -from bench_topk import make_topk_inputs, compute_histogram_stats -from vortex_torch_C import topk_profile_histogram +from bench_topk import make_topk_inputs, bench_kernel, compute_histogram_stats +from vortex_torch_C import topk_profile_histogram, topk_profile_counters, topk_output_sglang @@ -37,10 +37,22 @@ 3: ("power_exp", [0.1, 0.25, 0.75, 0.9]), 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), + 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), + 10: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), } BASELINES = { 0: ("none", 0.5), 4: ("log", 0.5), + 8: ("trunc8", 0.5), + 11: ("subtract", 0.5), +} +# Noscale baselines for parametric transform modes (skip auto-range pre-pass) +NOSCALE_BASELINES = { + 3: ("power_noscale", [0.5]), + 6: ("asinh_noscale", [1.0]), + 7: ("log1p_noscale", [1.0]), + 9: ("erf_noscale", [1.0]), + 10: ("tanh_noscale", [1.0]), } MODE_NAMES = { 0: "none", @@ -48,6 +60,10 @@ 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", + 9: "erf", + 10: "tanh", + 11: "subtract", } @@ -106,6 +122,56 @@ def build_bin_range_table(): return bin_lo, bin_hi +def generate_remap_lut(mode: int, param: float) -> np.ndarray: + """Generate a 256-entry uint8 LUT that approximates a transform mode. + + For each of the 256 fp16 radix bins, compute the transform of the + bin's midpoint value, then linearly map transformed values to [0,255]. + The resulting LUT can be used with mode=1 (LUT CDF) infrastructure, + replacing expensive per-element transcendental math with a single + shared memory lookup. + + Args: + mode: TopKMappingMode (3=Power, 4=Log, 6=Asinh, 7=Log1p, 9=Erf, 10=Tanh) + param: power_exp/beta/alpha for the transform + + Returns: + lut: [256] uint8 array mapping original_bin -> remapped_bin + """ + bin_lo, bin_hi = build_bin_range_table() + midpoints = (bin_lo + bin_hi) / 2.0 # [256] float32 + + # Apply transform + if mode == 3: # power + transformed = np.sign(midpoints) * np.abs(midpoints) ** param + elif mode == 4: # log + transformed = np.sign(midpoints) * np.log(np.abs(midpoints) + 1.0) + elif mode == 6: # asinh + transformed = np.arcsinh(param * midpoints) + elif mode == 7: # log1p + transformed = np.sign(midpoints) * np.log1p(param * np.abs(midpoints)) + elif mode == 9: # erf + from scipy.special import erf + transformed = erf(param * midpoints) + elif mode == 10: # tanh + transformed = np.tanh(param * midpoints) + else: + # Identity fallback + transformed = midpoints.copy() + + # Handle NaN/Inf from edge cases + transformed = np.nan_to_num(transformed, nan=0.0, posinf=0.0, neginf=0.0) + + # Linear map to [0, 255] + tmin, tmax = transformed.min(), transformed.max() + if tmax > tmin: + lut = np.clip(((transformed - tmin) / (tmax - tmin) * 255), 0, 255).astype(np.uint8) + else: + lut = np.full(256, 128, dtype=np.uint8) + + return lut + + def scores_from_histogram( histogram: np.ndarray, total_pages: int, @@ -232,7 +298,8 @@ def run_sweep(args) -> List[dict]: eff_bs = inputs["eff_batch_size"] - def evaluate(mode: int, power: float, label: str): + def evaluate(mode: int, power: float, label: str, noscale: bool = False, + lut_tensor=None): hists = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") topk_profile_histogram( inputs["x"], @@ -243,28 +310,62 @@ def evaluate(mode: int, power: float, label: str): args.reserved_eos, mode, power, - None, # lut + lut_tensor, # lut None, # quantiles + noscale, ) torch.cuda.synchronize() stats = compute_histogram_stats(hists) - return { + result = { "label": label, "mode": mode, "mode_name": MODE_NAMES.get(mode, f"m{mode}"), "param": power, + "noscale": noscale, "distribution": dist, "gini": stats["gini"], "max_mean_ratio": stats["max_mean_ratio"], "num_nonzero_bins": stats["num_nonzero_bins"], } + # Counter-based metrics (Stage 2 cost analysis) + if args.counters: + inputs["sparse_kv_indices"].zero_() + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + inputs["num_pages_per_seg"], + mode, + power, + lut_tensor, # lut + None, # quantiles + noscale, + ) + torch.cuda.synchronize() + c = counter_buf.float() + result["num_equal_mean"] = c[:, 2].mean().item() + result["remaining_k_mean"] = c[:, 3].mean().item() + result["refine_rounds_mean"] = c[:, 4].mean().item() + result["stage2_input_mean"] = c[:, 5].mean().item() + result["res_rate_mean"] = (c[:, 3] == 0).float().mean().item() + + return result + # Baselines for mode, (name, default_power) in BASELINES.items(): r = evaluate(mode, default_power, f"m{mode}_{name}") results.append(r) - # Parametric sweep + # Parametric sweep (scaled) for mode, (param_name, values) in SWEEP_GRID.items(): mname = MODE_NAMES[mode] for val in values: @@ -272,44 +373,122 @@ def evaluate(mode: int, power: float, label: str): r = evaluate(mode, val, label) results.append(r) + # Noscale sweep for parametric modes + for mode, (name, values) in NOSCALE_BASELINES.items(): + mname = MODE_NAMES[mode] + for val in values: + label = f"m{mode}_{mname}_noscale_{val}" + r = evaluate(mode, val, label, noscale=True) + results.append(r) + + # LUT approximation sweep: generate a LUT for each (mode, param) and + # evaluate via mode=1 (LUT CDF). This replaces per-element transcendentals + # with a single shared memory lookup. + if args.lut_sweep: + lut_modes = { + 3: [0.25, 0.5, 0.75], + 6: [0.5, 1.0, 2.0], + 7: [0.5, 1.0, 2.0], + 9: [0.5, 1.0, 2.0], + 10: [0.5, 1.0, 2.0], + } + for src_mode, params in lut_modes.items(): + src_name = MODE_NAMES[src_mode] + for p in params: + try: + lut_np = generate_remap_lut(src_mode, p) + lut_t = torch.from_numpy(lut_np).cuda() + label = f"lut_{src_name}_{p}" + # Evaluate as mode=1 (LUT CDF) with the generated LUT + r = evaluate(1, 0.5, label, lut_tensor=lut_t) + r["lut_source_mode"] = src_mode + r["lut_source_param"] = p + results.append(r) + except ImportError: + # scipy not available for erf + pass + return results -def print_table(results: List[dict]): +def print_table(results: List[dict], show_latency: bool = False): """Print ranked results as a formatted table.""" - # Sort by Gini ascending (lower = more uniform = better) - ranked = sorted(results, key=lambda r: r["gini"]) + has_counters = any("res_rate_mean" in r for r in results) + has_latency = any("full_kernel_ms" in r for r in results) - header = ( - f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} " - f"{'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" - ) - print("\n" + "=" * len(header)) - print("TopK Mapping Auto-Tune Results (ranked by Gini, lower=better)") - print("=" * len(header)) - print(header) - print("-" * len(header)) + # Primary ranking: by res_rate_mean (higher=better) if counters, else by gini (lower=better) + if has_counters: + ranked = sorted(results, key=lambda r: -r.get("res_rate_mean", 0.0)) + rank_label = "ranked by res_rate, higher=better" + else: + ranked = sorted(results, key=lambda r: r["gini"]) + rank_label = "ranked by Gini, lower=better" + + # Build header + cols = f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} {'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" + if has_counters: + cols += f" {'ResRate':>7s} {'RemK':>5s} {'Rnds':>4s} {'S2In':>5s}" + if has_latency and show_latency: + cols += f" {'LatMs':>9s} {'LatRk':>5s}" + + print(f"\n{'=' * len(cols)}") + print(f"TopK Mapping Auto-Tune Results ({rank_label})") + print("=" * len(cols)) + print(cols) + print("-" * len(cols)) for i, r in enumerate(ranked): - print( - f"{i+1:4d} {r['label']:<35s} {r['distribution']:<12s} " + noscale_tag = " [NS]" if r.get("noscale", False) else "" + line = ( + f"{i+1:4d} {r['label'] + noscale_tag:<35s} {r['distribution']:<12s} " f"{r['gini']:6.3f} " f"{r['max_mean_ratio']:8.2f} {r['num_nonzero_bins']:6d}" ) - - print("=" * len(header)) + if has_counters: + rr = r.get("res_rate_mean", 0.0) + rk = r.get("remaining_k_mean", 0.0) + rnds = r.get("refine_rounds_mean", 0.0) + s2in = r.get("stage2_input_mean", 0.0) + line += f" {rr:7.3f} {rk:5.0f} {rnds:4.1f} {s2in:5.0f}" + if has_latency and show_latency: + lat = r.get("full_kernel_ms", float("nan")) + lat_rank = r.get("latency_rank", "-") + line += f" {lat:9.4f} {lat_rank:>5s}" if isinstance(lat_rank, str) else f" {lat:9.4f} {lat_rank:5d}" + print(line) + + print("=" * len(cols)) if ranked: best = ranked[0] - print( + msg = ( f"\nBest overall: {best['label']} (dist={best['distribution']}) " f"— gini={best['gini']:.3f}, max/mean={best['max_mean_ratio']:.2f}" ) + if has_counters: + msg += f", res_rate={best.get('res_rate_mean', 0):.3f}" + if "full_kernel_ms" in best: + msg += f", latency={best['full_kernel_ms']:.4f}ms" + print(msg) + + # If latency data available, also print best by latency + if has_latency and show_latency: + lat_ranked = sorted([r for r in results if "full_kernel_ms" in r], + key=lambda r: r["full_kernel_ms"]) + if lat_ranked: + best_lat = lat_ranked[0] + print( + f"Best by latency: {best_lat['label']} (dist={best_lat['distribution']}) " + f"— latency={best_lat['full_kernel_ms']:.4f}ms, gini={best_lat['gini']:.3f}" + ) - # Per-mode best summary (lowest gini per mode) + # Per-mode best summary mode_best = {} for r in results: m = r["mode"] - if m not in mode_best or r["gini"] < mode_best[m]["gini"]: + if has_counters: + is_better = m not in mode_best or r.get("res_rate_mean", 0) > mode_best[m].get("res_rate_mean", 0) + else: + is_better = m not in mode_best or r["gini"] < mode_best[m]["gini"] + if is_better: mode_best[m] = r if mode_best: @@ -322,12 +501,94 @@ def print_table(results: List[dict]): param_str = f"{param_name}={r['param']}" else: param_str = "(baseline)" + ns_str = " noscale" if r.get("noscale", False) else "" + lat_str = f" latency={r['full_kernel_ms']:.4f}ms" if "full_kernel_ms" in r else "" + counter_str = f" res_rate={r.get('res_rate_mean', 0):.3f}" if has_counters else "" print( - f" Mode {m:d} ({mname:>5s}): {param_str:<20s} " - f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}" + f" Mode {m:d} ({mname:>5s}{ns_str}): {param_str:<20s} " + f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}{counter_str}{lat_str}" ) +def latency_rerank(results: List[dict], args) -> List[dict]: + """Re-rank top Gini candidates by actual kernel latency.""" + # Sort by Gini, take top N + ranked = sorted(results, key=lambda r: r["gini"]) + finalists = ranked[:args.latency_top_n] + + print(f"\n--- Latency re-ranking: timing top {len(finalists)} Gini finalists ---") + + # Build inputs for latency measurement + real_histogram = None + if args.real_histograms: + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + + if real_histogram is not None: + inputs = make_real_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + histogram=real_histogram, + ) + else: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution="normal", + ) + + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + + for r in finalists: + inputs["sparse_kv_indices"].zero_() + # For LUT-generated entries, regenerate the LUT tensor + lut_tensor = None + if "lut_source_mode" in r: + lut_np = generate_remap_lut(r["lut_source_mode"], r["lut_source_param"]) + lut_tensor = torch.from_numpy(lut_np).cuda() + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + r["mode"], + r["param"], + lut_tensor, # lut + None, # quantiles + r.get("noscale", False), + ) + latency = bench_kernel(topk_output_sglang, call_args, + warmup=10, repeat=args.latency_repeat) + r["full_kernel_ms"] = latency["mean_ms"] + print(f" {r['label']:<35s} gini={r['gini']:.3f} latency={latency['mean_ms']:.4f}ms") + + # Re-rank finalists by latency + finalists.sort(key=lambda r: r["full_kernel_ms"]) + for i, r in enumerate(finalists): + r["latency_rank"] = i + 1 + r["gini_rank"] = next(j+1 for j, x in enumerate(ranked) if x is r) + + return results + + def main(): parser = argparse.ArgumentParser( description="Auto-tune TopK mapping hyperparameters" @@ -353,6 +614,16 @@ def main(): "--output-json", type=str, default=None, help="Save results to JSON file", ) + parser.add_argument("--latency-rerank", action="store_true", + help="Re-rank top Gini finalists by actual kernel latency") + parser.add_argument("--latency-top-n", type=int, default=10, + help="Number of Gini finalists to re-rank by latency (default: 10)") + parser.add_argument("--latency-repeat", type=int, default=50, + help="Kernel timing repetitions for latency measurement (default: 50)") + parser.add_argument("--counters", action="store_true", + help="Collect counter-based metrics (Stage 2 cost analysis) for each config") + parser.add_argument("--lut-sweep", action="store_true", + help="Generate and evaluate LUT approximations for parametric transform modes") args = parser.parse_args() source = f"real ({args.real_histograms})" if args.real_histograms else f"synthetic ({args.distributions})" @@ -361,12 +632,17 @@ def main(): f"topk_val={args.topk_val}, num_kv_heads={args.num_kv_heads}") print(f" score source: {source}") n_parametric = sum(len(v) for _, v in SWEEP_GRID.values()) + n_baselines = len(BASELINES) n_dists = 1 if args.real_histograms else len(args.distributions) - print(f" sweep: {n_parametric} parametric + {len(BASELINES)} baselines " - f"= {n_parametric + len(BASELINES)} combos x {n_dists} dists") + print(f" sweep: {n_parametric} parametric + {n_baselines} baselines " + f"= {n_parametric + n_baselines} combos x {n_dists} dists") results = run_sweep(args) - print_table(results) + + if args.latency_rerank: + results = latency_rerank(results, args) + + print_table(results, show_latency=args.latency_rerank) if args.output_json: with open(args.output_json, "w") as f: diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index ca039f2f..675092e5 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -18,7 +18,10 @@ import numpy as np import torch -from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram +from vortex_torch_C import ( + topk_output, topk_output_sglang, topk_profile_histogram, + topk_profile_stage1, topk_profile_counters, +) # Canonical mapping mode names — used in logs, tables, and plots MAPPING_MODE_NAMES = { @@ -31,6 +34,9 @@ 6: "Asinh", 7: "Log1p", 8: "Trunc8", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", } MAPPING_MODE_FORMULAS = { @@ -43,6 +49,9 @@ 6: "Asinh: asinh(beta*x)", 7: "Log1p: sign(x)*log1p(alpha*|x|)", 8: "Trunc8: bf16 upper-8-bit bucketing", + 9: "Erf: erf(alpha*x)", + 10: "Tanh: tanh(alpha*x)", + 11: "Subtract: x - pivot (RadiK-style)", } @@ -200,7 +209,7 @@ def _load_autotune_powers(path: str) -> Dict[int, float]: best: Dict[int, dict] = {} for r in data: m = r.get("mode") - if m not in (3, 6, 7): + if m not in (3, 6, 7, 9, 10): continue if has_res_rate: score = r.get("res_rate_mean", 0.0) @@ -219,7 +228,8 @@ def _resolve_mode_power(args, mode: int) -> float: Priority: per-mode CLI flag > autotune JSON > global --mapping-power. """ - per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7} + per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7, + 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None)} if mode in per_mode_flag and per_mode_flag[mode] is not None: return per_mode_flag[mode] if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: @@ -276,11 +286,20 @@ def run_benchmark(args) -> List[dict]: all_kernels = { "naive": "naive", "sglang_m0": "sglang_m0", + "sglang_scale": "sglang_scale", # mode 3 with p=1.0 (identity + linear auto-range scaling) "sglang_m3": "sglang_m3", + "sglang_m3_noscale": "sglang_m3_noscale", "sglang_m4": "sglang_m4", "sglang_m6": "sglang_m6", + "sglang_m6_noscale": "sglang_m6_noscale", "sglang_m7": "sglang_m7", + "sglang_m7_noscale": "sglang_m7_noscale", "sglang_m8": "sglang_m8", + "sglang_m9": "sglang_m9", + "sglang_m9_noscale": "sglang_m9_noscale", + "sglang_m10": "sglang_m10", + "sglang_m10_noscale": "sglang_m10_noscale", + "sglang_m11": "sglang_m11", } if mapping_lut is not None: all_kernels["sglang_m1"] = "sglang_m1" @@ -288,6 +307,23 @@ def run_benchmark(args) -> List[dict]: all_kernels["sglang_m2"] = "sglang_m2" if args.filter_kernels: + # Validate: if the user explicitly requested sglang_m1 or sglang_m2 but + # the required calibration file was not provided, fail loudly instead of + # silently skipping these modes. + if "sglang_m1" in args.filter_kernels and "sglang_m1" not in all_kernels: + raise RuntimeError( + "sglang_m1 (LUT CDF) was requested in --filter-kernels but no " + "--lut-path was provided. Mode 1 requires a calibrated LUT file " + "(lut.npy from calibrate_topk.py). Either supply --lut-path or " + "remove sglang_m1 from --filter-kernels." + ) + if "sglang_m2" in args.filter_kernels and "sglang_m2" not in all_kernels: + raise RuntimeError( + "sglang_m2 (Quantile) was requested in --filter-kernels but no " + "--quantiles-path was provided. Mode 2 requires a calibrated " + "quantiles file (quantiles.npy from calibrate_topk.py). Either " + "supply --quantiles-path or remove sglang_m2 from --filter-kernels." + ) all_kernels = {k: v for k, v in all_kernels.items() if k in args.filter_kernels} # Naive kernel only supports bf16 @@ -352,12 +388,14 @@ def run_benchmark(args) -> List[dict]: "kernels": {}, } + # Collect all kernel results first, then print sorted by latency + kernel_entries = [] # [(label, kernel_name, result)] + for kernel_name in all_kernels: # Reset sparse indices each run inputs["sparse_kv_indices"].zero_() if kernel_name == "naive": - # topk_output: (x, dense_indptr, dense_indices, sparse_indptr, sparse_indices, ...) call_args = ( inputs["x"], inputs["dense_kv_indptr"], @@ -371,18 +409,39 @@ def run_benchmark(args) -> List[dict]: pages_per_seg, ) result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) + elif kernel_name == "sglang_scale": + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + 3, # mode 3 (power) + 1.0, # p=1.0 → identity + None, + None, + ) + result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) else: - # Parse mapping mode from kernel name - mode = int(kernel_name.split("_m")[1]) + mode_str = kernel_name.split("_m")[1] + mode = int(mode_str.split("_")[0]) + is_noscale = kernel_name.endswith("_noscale") extra_kwargs = {} if mode == 1: extra_kwargs["mapping_lut"] = mapping_lut elif mode == 2: extra_kwargs["mapping_quantiles"] = mapping_quantiles - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + if mode in (3, 6, 7, 9, 10): + power = _resolve_mode_power(args, mode) + else: + power = 0.5 - # topk_output_sglang: (x, dense_indptr, sparse_indptr, dense_indices, sparse_indices, ...) call_args = ( inputs["x"], inputs["dense_kv_indptr"], @@ -398,25 +457,175 @@ def run_benchmark(args) -> List[dict]: power, extra_kwargs.get("mapping_lut", None), extra_kwargs.get("mapping_quantiles", None), + is_noscale, ) result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) + # Build label if kernel_name == "naive": label = "naive" + elif kernel_name == "sglang_scale": + label = "sglang Scale Only (p=1.0)" else: - m = int(kernel_name.split("_m")[1]) + m_str = kernel_name.split("_m")[1] + m = int(m_str.split("_")[0]) + noscale_suffix = " noscale" if kernel_name.endswith("_noscale") else "" mname = MAPPING_MODE_NAMES.get(m, f'm{m}') - if m in (3, 6, 7): - pname = {3: "p", 6: "beta", 7: "alpha"}[m] - label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)})" + if m in (3, 6, 7, 9, 10): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[m] + label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)}){noscale_suffix}" + else: + label = f"sglang {mname}{noscale_suffix}" + + # Sub-phase profiling for sglang kernels + if kernel_name != "naive": + if kernel_name == "sglang_scale": + s1_mode, s1_power = 3, 1.0 + s1_lut, s1_q = None, None + s1_noscale = False else: - label = f"sglang {mname}" + s1_mode_str = kernel_name.split("_m")[1] + s1_mode = int(s1_mode_str.split("_")[0]) + s1_noscale = kernel_name.endswith("_noscale") + if s1_mode in (3, 6, 7, 9, 10): + s1_power = _resolve_mode_power(args, s1_mode) + else: + s1_power = 0.5 + s1_lut = mapping_lut if s1_mode == 1 else None + s1_q = mapping_quantiles if s1_mode == 2 else None + + # Histogram only: pre-pass + histogram build + hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + hist_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + hist_buf, + eff_bs, + args.reserved_bos, + args.reserved_eos, + s1_mode, + s1_power, + s1_lut, + s1_q, + s1_noscale, + ) + hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) + + # Stage1 full: pre-pass + hist + cumsum + route/filter + inputs["sparse_kv_indices"].zero_() + stage1_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + s1_mode, + s1_power, + s1_lut, + s1_q, + s1_noscale, + ) + stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) + + result['histogram_only_mean_ms'] = hist_result['mean_ms'] + result['histogram_only_median_ms'] = hist_result['median_ms'] + result['stage1_full_mean_ms'] = stage1_result['mean_ms'] + result['stage1_full_median_ms'] = stage1_result['median_ms'] + result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] + result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] + result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] + result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] + + # Optional counter collection + if args.counters: + inputs["sparse_kv_indices"].zero_() + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + counter_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + s1_mode, + s1_power, + s1_lut, + s1_q, + s1_noscale, + ) + topk_profile_counters(*counter_args) + torch.cuda.synchronize() + c = counter_buf.float() + result['counters'] = { + 'threshold_bin_mean': c[:, 0].mean().item(), + 'num_above_mean': c[:, 1].mean().item(), + 'num_equal_mean': c[:, 2].mean().item(), + 'remaining_k_mean': c[:, 3].mean().item(), + 'refine_rounds_mean': c[:, 4].mean().item(), + 'stage2_input_mean': c[:, 5].mean().item(), + 'threshold_bin_max': c[:, 0].max().item(), + 'num_above_max': c[:, 1].max().item(), + 'num_equal_max': c[:, 2].max().item(), + 'remaining_k_max': c[:, 3].max().item(), + 'refine_rounds_max': c[:, 4].max().item(), + 'stage2_input_max': c[:, 5].max().item(), + } + + kernel_entries.append((label, kernel_name, result)) + config_results["kernels"][kernel_name] = result + + # Print kernel results sorted by mean latency (ascending) + kernel_entries.sort(key=lambda e: e[2]['mean_ms']) + print(f" --- kernel latency (sorted by mean, ascending) ---") + for label, kernel_name, result in kernel_entries: print( - f" {label:<30s}: {result['median_ms']:.4f}ms (median) " + f" {label:<40s}: " + f"mean={result['mean_ms']:.4f}ms " + f"median={result['median_ms']:.4f}ms " f"\u00b1 {result['std_ms']:.4f}ms " f"[min={result['min_ms']:.4f}, max={result['max_ms']:.4f}]" ) - config_results["kernels"][kernel_name] = result + if 'stage1_full_mean_ms' in result: + print( + f" {'Histogram only (map+hist)':<36s}: " + f"mean={result['histogram_only_mean_ms']:.4f}ms " + f"median={result['histogram_only_median_ms']:.4f}ms" + ) + print( + f" {'Stage1 full (hist+cumsum+route)':<36s}: " + f"mean={result['stage1_full_mean_ms']:.4f}ms " + f"median={result['stage1_full_median_ms']:.4f}ms" + ) + print( + f" {'Route overhead (cumsum+route)':<36s}: " + f"mean={result['route_overhead_mean_ms']:.4f}ms " + f"median={result['route_overhead_median_ms']:.4f}ms" + ) + print( + f" {'Stage2 (refine)':<36s}: " + f"mean={result['stage2_refine_mean_ms']:.4f}ms " + f"median={result['stage2_refine_median_ms']:.4f}ms" + ) + if 'counters' in result: + c = result['counters'] + print( + f" Counters: threshold_bin={c['threshold_bin_mean']:.0f} " + f"above={c['num_above_mean']:.0f} " + f"equal={c['num_equal_mean']:.0f} " + f"remaining_k={c['remaining_k_mean']:.0f} " + f"refine_rounds={c['refine_rounds_mean']:.1f} " + f"stage2_input={c['stage2_input_mean']:.0f}" + ) # Histogram analysis if args.histogram: @@ -476,22 +685,25 @@ def run_benchmark(args) -> List[dict]: f"nonzero_bins={hstats['num_nonzero_bins']}/256" ) - # Per-mode histogram analysis - modes_to_test = [0, 3, 4, 6, 7, 8] + # Collect all histogram entries, then print sorted by gini + # Each entry: (display_name, key, mode_stats) + hist_entries = [] + histograms_results = {} + + # Per-mode histogram analysis (scaled) + modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11] if mapping_lut is not None: modes_to_test.append(1) if mapping_quantiles is not None: modes_to_test.append(2) modes_to_test.sort() - histograms_results = {} - print(f" --- histogram by mapping mode ---") for mode in modes_to_test: mode_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") extra_lut = mapping_lut if mode == 1 else None extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7) else 0.5 + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10) else 0.5 topk_profile_histogram( hist_inputs["x"], @@ -513,23 +725,84 @@ def run_benchmark(args) -> List[dict]: mformula = MAPPING_MODE_FORMULAS.get(mode, mname) mode_stats["name"] = mname mode_stats["formula"] = mformula - if mode in (3, 6, 7): - pname = {3: "p", 6: "beta", 7: "alpha"}[mode] + if mode in (3, 6, 7, 9, 10): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] mode_stats["param"] = f"{pname}={power}" - histograms_results[f"mode_{mode}_{mname}"] = mode_stats - if mode in (3, 6, 7): - pname = {3: "p", 6: "beta", 7: "alpha"}[mode] display_name = f"{mname} ({pname}={power})" else: display_name = mname + key = f"mode_{mode}_{mname}" + histograms_results[key] = mode_stats + hist_entries.append((display_name, f"mode {mode:2d}", mode_stats)) + + # Noscale histogram analysis for parametric transform modes + noscale_modes = [m for m in (3, 6, 7, 9, 10) if m in modes_to_test] + for mode in noscale_modes: + ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + power = _resolve_mode_power(args, mode) + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + ns_hists, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + None, + None, + True, # mapping_noscale=True + ) + torch.cuda.synchronize() + ns_stats = compute_histogram_stats(ns_hists) + ns_stats["raw_counts"] = ns_hists.sum(dim=0).tolist() + mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") + mformula = MAPPING_MODE_FORMULAS.get(mode, mname) + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] + ns_stats["name"] = f"{mname} noscale" + ns_stats["formula"] = mformula + ns_stats["param"] = f"{pname}={power}" + display_name = f"{mname} noscale ({pname}={power})" + key = f"mode_{mode}_{mname}_noscale" + histograms_results[key] = ns_stats + hist_entries.append((display_name, f"m{mode:2d} ns", ns_stats)) + + # Scale Only baseline: mode 3 with p=1.0 (identity + linear scaling) + scale_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + hist_inputs["x"], + hist_inputs["dense_kv_indptr"], + scale_hists, + hist_eff_bs, + args.reserved_bos, + args.reserved_eos, + 3, # mode 3 (power) + 1.0, # p=1.0 → identity transform + None, + None, + ) + torch.cuda.synchronize() + scale_stats = compute_histogram_stats(scale_hists) + scale_stats["raw_counts"] = scale_hists.sum(dim=0).tolist() + scale_stats["name"] = "Scale Only" + scale_stats["formula"] = "Identity + linear scaling to [0,255]" + scale_stats["param"] = "p=1.0" + histograms_results["mode_scale_Scale Only"] = scale_stats + hist_entries.append(("Scale Only (p=1.0)", "scale ", scale_stats)) + + # Print all histogram entries sorted by gini (ascending = more uniform = better) + hist_entries.sort(key=lambda e: e[2]['gini']) + print(f" --- histogram by gini (sorted, lower=better) ---") + for rank, (display_name, mode_tag, stats) in enumerate(hist_entries, 1): print( - f" {display_name:<22s} (mode {mode}): " - f"gini={mode_stats['gini']:.3f} " - f"max/mean={mode_stats['max_mean_ratio']:.2f} " - f"nonzero_bins={mode_stats['num_nonzero_bins']}/256 " - f"eff_bins={mode_stats['effective_bins']:.1f} " - f"entropy={mode_stats['entropy']:.2f}" + f" {rank:2d}. {display_name:<32s} ({mode_tag}): " + f"gini={stats['gini']:.3f} " + f"max/mean={stats['max_mean_ratio']:.2f} " + f"nonzero_bins={stats['num_nonzero_bins']}/256 " + f"eff_bins={stats['effective_bins']:.1f} " + f"entropy={stats['entropy']:.2f}" ) + config_results["histograms"] = histograms_results all_results.append(config_results) @@ -573,6 +846,9 @@ def main(): "Only used when --histogram is set.") parser.add_argument("--real-histograms", type=str, default=None, help="Path to .npy raw_histograms from calibration (adds 'real' distribution)") + parser.add_argument("--counters", action="store_true", + help="Collect diagnostic counters (threshold_bin, num_above, num_equal, " + "remaining_k, refine_rounds, stage2_input) for each sglang kernel") args = parser.parse_args() results = run_benchmark(args) diff --git a/csrc/clean.py b/csrc/clean.py new file mode 100644 index 00000000..8d258bb0 --- /dev/null +++ b/csrc/clean.py @@ -0,0 +1,21 @@ +from pathlib import Path +import sys + +def clean_one_leading_space(path: str): + p = Path(path) + text = p.read_text(encoding="utf-8") + + cleaned = "".join( + line[1:] if line.startswith(" ") else line + for line in text.splitlines(keepends=True) + ) + + p.write_text(cleaned, encoding="utf-8") + print(f"Cleaned: {p}") + +if __name__ == "__main__": + if len(sys.argv) != 2: + print("Usage: python clean_indent.py ") + sys.exit(1) + + clean_one_leading_space(sys.argv[1]) \ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index 00674743..0a3c11ca 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -17,7 +17,8 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none()); + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), @@ -25,7 +26,31 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none()); + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); + m.def("topk_profile_stage1", &topk_profile_stage1, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); + m.def("topk_profile_counters", &topk_profile_counters, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("counters"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode") = 0, + py::arg("mapping_power") = 0.5, + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none(), + py::arg("mapping_noscale") = false); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index d4f2d8b4..1a8b8207 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -99,7 +99,8 @@ const int64_t max_seq_lengths, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false ); void topk_profile_histogram( @@ -112,7 +113,45 @@ const int64_t reserved_eos, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false +); + +void topk_profile_stage1( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false +); + +void topk_profile_counters( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& counters, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode = 0, +const double mapping_power = 0.5, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt, +const bool mapping_noscale = false ); void sglang_plan_decode_fa3( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index e3fe3a73..97bc141d 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -26,6 +26,9 @@ enum TopKMappingMode { MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing }; struct TopKMappingParams { @@ -33,6 +36,8 @@ struct TopKMappingParams { float power_exp; // For MAPPING_POWER (default 0.5) const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr + bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) + int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) }; // NOTE: convert_to_uint8() must be defined before including this header. @@ -56,6 +61,14 @@ __device__ __forceinline__ float transform_log1p(float x, float alpha) { return copysignf(log1pf(alpha * fabsf(x)), x); } +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + // ---- Transform dispatcher (returns float, no bucketing) ---- __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { @@ -64,6 +77,8 @@ __device__ __forceinline__ float apply_transform(float x, const TopKMappingParam case MAPPING_LOG: return transform_log(x); case MAPPING_ASINH: return transform_asinh(x, params.power_exp); case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); default: return x; } } @@ -75,14 +90,16 @@ __device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_mi return static_cast(min(max(bin, 0), 255)); } -// ---- BF16 upper-8-bit bucketing (mode 8) ---- +// ---- BF16-aware bucketing (mode 8) ---- +// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the +// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical +// data (the byte is almost entirely exponent). Instead, convert through +// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the +// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but +// explicitly available as a named mode for documentation/benchmarking. __device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { - __nv_bfloat16 bf = __float2bfloat16_rn(x); - uint16_t bits = __bfloat16_as_ushort(bf); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) - : static_cast(bits | 0x8000); - return static_cast(key >> 8); + return convert_to_uint8(x); // fp16 sign-flip bucketing } // ---- Non-transform mapping functions (unchanged) ---- @@ -130,12 +147,17 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( case MAPPING_POWER: case MAPPING_LOG: case MAPPING_ASINH: - case MAPPING_LOG1P: { + case MAPPING_LOG1P: + case MAPPING_ERF: + case MAPPING_TANH: { float val = apply_transform(x, params); + if (params.noscale) return convert_to_uint8(val); return linear_map_to_uint8(val, range_min, inv_range); } case MAPPING_TRUNC8: return convert_to_uint8_bf16(x); + case MAPPING_SUBTRACT: + return convert_to_uint8(x - range_min); // range_min repurposed as pivot default: // MAPPING_NONE return convert_to_uint8(x); } @@ -144,5 +166,11 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( // Helper: check if a mapping mode needs the auto-range pre-pass __device__ __forceinline__ bool needs_auto_range(int mode) { return (mode == MAPPING_POWER || mode == MAPPING_LOG || - mode == MAPPING_ASINH || mode == MAPPING_LOG1P); + mode == MAPPING_ASINH || mode == MAPPING_LOG1P || + mode == MAPPING_ERF || mode == MAPPING_TANH); +} + +// Helper: check if a mapping mode needs the pivot pre-pass +__device__ __forceinline__ bool needs_pivot(int mode) { + return (mode == MAPPING_SUBTRACT); } diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 59592708..9213016a 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -452,18 +452,30 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) constexpr int VORTEX_MAX_TOPK = 2048; +// Per-segment diagnostic counters written by WriteCounters mode +constexpr int COUNTER_THRESHOLD_BIN = 0; // Stage 1 coarse threshold bin id +constexpr int COUNTER_NUM_ABOVE = 1; // elements routed above threshold in Stage 1 +constexpr int COUNTER_NUM_EQUAL = 2; // elements in threshold bin (Stage 2 input) +constexpr int COUNTER_REMAINING_K = 3; // topk slots remaining after Stage 1 routing +constexpr int COUNTER_REFINE_ROUNDS = 4; // Stage 2 rounds used (0 = resolved in Stage 1) +constexpr int COUNTER_STAGE2_INPUT = 5; // candidates entering first Stage 2 refine round +constexpr int NUM_TOPK_COUNTERS = 6; + // Templated version of fast_topk_cuda_tl: // - ScoreT: float or __nv_bfloat16 +// - StopAfterStage1: return after Stage 1 route/filter (for profiling) +// - WriteCounters: write diagnostic counters to global memory // - target_k: runtime parameter (replaces compile-time TopK) // - mapping: configurable value-remapping for Stage 1 bin assignment -template +template __device__ void fast_topk_vortex( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k, - const TopKMappingParams& mapping) + const TopKMappingParams& mapping, + int* counters = nullptr) { int topk = target_k; constexpr auto BLOCK_SIZE = 1024; @@ -497,10 +509,14 @@ __device__ void fast_topk_vortex( __syncthreads(); } - // Pre-pass: compute per-block min/max of transformed values for linear bucketing - if (needs_auto_range(mapping.mode)) { + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. + // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); local_min = fminf(local_min, val); local_max = fmaxf(local_max, val); @@ -528,12 +544,46 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. + float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); } // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. + constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + if (tx < RADIX + 1) vh_histogram[tx] = 0; __syncthreads(); @@ -543,6 +593,9 @@ __device__ void fast_topk_vortex( mapping, s_mapping_lut, s_mapping_quantiles, s_range_min, s_range_inv_range); ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } } __syncthreads(); @@ -574,19 +627,35 @@ __device__ void fast_topk_vortex( const auto threshold_bin = vh_threshold_bin_id; topk -= vh_histogram[threshold_bin + 1]; + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + if (topk == 0) { for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast( - mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } if (bin > threshold_bin) { const auto pos = ::atomicAdd(&vh_counter, 1); index[pos] = idx; } } __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } return; } else { __syncthreads(); @@ -595,10 +664,15 @@ __device__ void fast_topk_vortex( for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast( - mapped_convert_to_uint8(raw_input, mapping, - s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } if (bin > threshold_bin) { const auto pos = ::atomicAdd(&vh_counter, 1); index[pos] = idx; @@ -613,9 +687,19 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; } // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } #pragma unroll 4 for (int round = 0; round < 4; ++round) { __shared__ int vh_last_remain; @@ -649,6 +733,11 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } break; } else { __syncthreads(); @@ -722,6 +811,92 @@ void TopKOutput_Kernel( } } +// ====================================================================== +// Profiling Stage1 kernel: runs pre-pass + hist + cumsum + route/filter, +// stops before Stage 2 refinement (for sub-phase timing) +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKStage1_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling counters kernel: runs full pipeline + writes diagnostic +// counters to a separate global-memory tensor +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKCounters_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + // ====================================================================== // Profiling histogram kernel: runs only Stage 1 and returns per-segment // 256-bin histograms for distribution analysis @@ -762,10 +937,11 @@ void TopKHistogram_Kernel( __syncthreads(); } - // Pre-pass: compute per-block min/max for transform modes - if (needs_auto_range(mapping.mode)) { + // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); local_min = fminf(local_min, val); local_max = fmaxf(local_max, val); @@ -791,6 +967,30 @@ void TopKHistogram_Kernel( } } __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean for MAPPING_SUBTRACT + float local_sum = 0.0f; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(score_blk[idx]); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums_h[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(nblk); + s_range_inv_range = 0.0f; + } + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -949,7 +1149,8 @@ void topk_output_sglang( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles) + std::optional mapping_quantiles, + const bool mapping_noscale) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, "topk_output: topk_val (", topk_val, @@ -961,6 +1162,8 @@ void topk_output_sglang( mapping.power_exp = static_cast(mapping_power); mapping.lut = nullptr; mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = 1; if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1029,7 +1232,8 @@ void topk_profile_histogram( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles) + std::optional mapping_quantiles, + const bool mapping_noscale) { CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); @@ -1046,6 +1250,8 @@ void topk_profile_histogram( mapping.power_exp = static_cast(mapping_power); mapping.lut = nullptr; mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = 1; if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1093,3 +1299,175 @@ void topk_profile_histogram( "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); } +// Helper: build TopKMappingParams from host arguments +static TopKMappingParams build_mapping_params( + int64_t mapping_mode, double mapping_power, + std::optional& mapping_lut, + std::optional& mapping_quantiles, + bool mapping_noscale = false, + int sample_stride = 1) +{ + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = sample_stride; + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + return mapping; +} + +// ====================================================================== +// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) +// ====================================================================== +void topk_profile_stage1( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_stage1: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_stage1: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: full pipeline + diagnostic counters +// ====================================================================== +void topk_profile_counters( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& counters, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, + "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_counters: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); +} + diff --git a/csrc/topk_slgang_ori.cu b/csrc/topk_slgang_ori.cu new file mode 100644 index 00000000..04a2b73b --- /dev/null +++ b/csrc/topk_slgang_ori.cu @@ -0,0 +1,546 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 287c4545..3022dda7 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -5,11 +5,13 @@ # Profiles the SGLang TopK kernel's first-pass bucket distribution # to identify hotspot buckets causing tail latency. # -# Three steps: -# 1. Calibrate — collect real-data histograms -# (skippable via --real-histograms PATH) -# 2. Bench — histogram profiling (bucket_uniform + normal) -# 3. Analyze — comparison plots + bucket count tables +# Four steps: +# 1. Calibrate — collect real-data histograms +# (skippable via --real-histograms PATH) +# 2. Auto-tune — sweep hyperparameters to find best per-mode power +# 3. Bench — histogram profiling (bucket_uniform + normal) +# noscale kernels use the same autotuned power +# 4. Analyze — comparison plots + bucket count tables # # All outputs (JSON, plots, CSV tables, logs) are written to a # single timestamped folder under examples/results/dist_analysis_*. @@ -30,6 +32,9 @@ # 6: Asinh — y = asinh(beta * x) # 7: Log1p — y = sign(x) * log1p(alpha * |x|) # 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) set -euo pipefail @@ -37,14 +42,14 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 +GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) -# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" + # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in @@ -97,45 +102,101 @@ else echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" fi -# ── Step 2: Histogram profiling (bucket_uniform + normal) ───── +# ── Step 2: Auto-tune — sweep hyperparameters ────────────────── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7, 9, 10)" + +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + +# Build autotune data source args +AUTOTUNE_EXTRA_ARGS=() +if [ -n "${REAL_HIST_PATH:-}" ]; then + AUTOTUNE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") +fi + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len 32768 \ + --num-kv-heads 2 \ + "${AUTOTUNE_EXTRA_ARGS[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + +echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + +# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── echo "" -echo ">>> Step 2: Kernel-level histogram profiling (bucket_uniform + normal)" +echo ">>> Step 3: Kernel-level histogram profiling (bucket_uniform + normal)" BENCH_JSON="${RUN_DIR}/bench_distribution.json" -python "${BENCH_DIR}/bench_topk.py" \ +# Build optional args for bench_topk.py +BENCH_EXTRA_ARGS=() +if [ -n "${REAL_HIST_PATH:-}" ]; then + BENCH_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") +fi + +# Derive calibration directory from histogram path to find lut.npy / quantiles.npy +CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_FILE="${CALIB_DIR}/lut.npy" +QUANTILES_FILE="${CALIB_DIR}/quantiles.npy" + +if [ -f "${LUT_FILE}" ]; then + BENCH_EXTRA_ARGS+=(--lut-path "${LUT_FILE}") + echo " Using LUT for mode 1: ${LUT_FILE}" +else + echo " WARNING: ${LUT_FILE} not found — mode 1 (LUT CDF) will be skipped" +fi +if [ -f "${QUANTILES_FILE}" ]; then + BENCH_EXTRA_ARGS+=(--quantiles-path "${QUANTILES_FILE}") + echo " Using quantiles for mode 2: ${QUANTILES_FILE}" +else + echo " WARNING: ${QUANTILES_FILE} not found — mode 2 (Quantile) will be skipped" +fi + +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 4096 \ + --seq-lens 32768 \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 2 \ + --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ - --filter-kernels sglang_m0 sglang_m1 sglang_m2 sglang_m3 sglang_m4 sglang_m6 sglang_m7 sglang_m8 \ + "${BENCH_EXTRA_ARGS[@]}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --filter-kernels naive sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step2_bench.log" + 2>&1 | tee "${RUN_DIR}/step3_bench.log" -echo ">>> Step 2: Done. Results saved to ${BENCH_JSON}" +echo ">>> Step 3: Done. Results saved to ${BENCH_JSON}" -# ── Step 3: Analyze — comparison plots + tables ─────────────── +# ── Step 4: Analyze — comparison plots + tables ─────────────── echo "" -echo ">>> Step 3: Generating distribution comparison plots + tables" +echo ">>> Step 4: Generating distribution comparison plots + tables" + +# Build optional args for analyze +ANALYZE_EXTRA_ARGS=() +if [ -n "${REAL_HIST_PATH:-}" ]; then + ANALYZE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") +fi python "${BENCH_DIR}/analyze_topk_distribution.py" \ --bench-json "${BENCH_JSON}" \ - --real-histograms "${REAL_HIST_PATH}" \ + "${ANALYZE_EXTRA_ARGS[@]}" \ --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step3_analyze.log" + 2>&1 | tee "${RUN_DIR}/step4_analyze.log" -echo ">>> Step 3: Done." +echo ">>> Step 4: Done." # ── Summary ─────────────────────────────────────────────────── echo "" echo "============================================================" echo "Bucket Distribution Profiling Complete" echo " All outputs in: ${RUN_DIR}/" +echo " autotune_results.json — hyperparameter sweep rankings" echo " bench_distribution.json — raw benchmark data" echo " distribution_comparison.png — bucket dist plots" echo " bucket_counts.csv — per-bucket count table" -echo " step{1,2,3}_*.log — pipeline logs" +echo " step{1,2,3,4}_*.log — pipeline logs" echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 3dc1bd41..f0938fff 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -6,7 +6,10 @@ # Mode 3 (Power): y = sign(x) * |x|^p # Mode 6 (Asinh): y = asinh(beta * x) # Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) -# Mode 8 (Trunc8): bf16 upper-8-bit bucketing +# Mode 8 (Trunc8): bf16 upper-8-bit bucketing +# Mode 9 (Erf): y = erf(alpha * x) +# Mode 10 (Tanh): y = tanh(alpha * x) +# Mode 11 (Subtract): x - pivot (RadiK-style scatter) # # Four steps: # 1. Calibrate — collect real-data histograms @@ -26,7 +29,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 +GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 @@ -56,7 +59,7 @@ RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" -echo "Bucket Distribution Profiling (modes 3, 6, 7)" +echo "Bucket Distribution Profiling (modes 3, 6, 7, 8, 9, 10, 11)" echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" @@ -95,9 +98,10 @@ AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 4096 \ - --num-kv-heads 2 \ + --seq-len 32768 \ + --num-kv-heads 8 \ --real-histograms "${REAL_HIST_PATH}" \ + --latency-rerank \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2_autotune.log" @@ -111,14 +115,14 @@ BENCH_JSON="${RUN_DIR}/bench_distribution.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 4096 \ + --seq-lens 32768 \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 2 \ + --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ --real-histograms "${REAL_HIST_PATH}" \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels sglang_m3 sglang_m6 sglang_m7 sglang_m8 \ + --filter-kernels naive sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" @@ -134,13 +138,12 @@ python "${BENCH_DIR}/analyze_topk_distribution.py" \ --real-histograms "${REAL_HIST_PATH}" \ --output-dir "${RUN_DIR}" \ 2>&1 | tee "${RUN_DIR}/step4_analyze.log" - echo ">>> Step 4: Done." # ── Summary ─────────────────────────────────────────────────── echo "" echo "============================================================" -echo "Bucket Distribution Profiling Complete (modes 3, 6, 7)" +echo "Bucket Distribution Profiling Complete (modes 3, 6, 7, 8, 9, 10, 11)" echo " All outputs in: ${RUN_DIR}/" echo " autotune_results.json — hyperparameter sweep rankings" echo " bench_distribution.json — raw benchmark data" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index 5a7ed94e..6ac2b9de 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -34,7 +34,7 @@ set -euo pipefail # use GPU_ID to set the GPU id you want to use -GPU_ID=5 +GPU_ID=4 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" @@ -118,7 +118,7 @@ else python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 4096 \ + --seq-len 32768 \ --num-kv-heads 2 \ ${REAL_HIST_ARGS} \ --output-json "${AUTOTUNE_JSON}" \ @@ -143,7 +143,7 @@ else python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 8 16 32 \ - --seq-lens 2048 4096 8192 16384 \ + --seq-lens 2048 4096 8192 16384 32768 \ --topk-vals "${TOPK_VAL}" \ --num-kv-heads 2 4 \ --distributions normal lognormal uniform \ @@ -263,6 +263,32 @@ else --topk-mapping-mode 7 \ --topk-mapping-power 1.0 + # 3j. SGLang mode 8 (Trunc8) + run_e2e "sglang_mode8_trunc8" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 8 + + # 3k. SGLang mode 9 (Erf) + run_e2e "sglang_mode9_erf" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 9 \ + --topk-mapping-power 1.0 + + # 3l. SGLang mode 10 (Tanh) + run_e2e "sglang_mode10_tanh" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 10 \ + --topk-mapping-power 1.0 + + # 3m. SGLang mode 11 (Subtract) + run_e2e "sglang_mode11_subtract" \ + --vortex-module-name "${ALGO}" \ + --topk-type sglang \ + --topk-mapping-mode 11 + echo "" echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index e04f787c..f1c2a2d7 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -254,8 +254,8 @@ def parse_args(): "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p (default: 0).', + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract (default: 0).', ) parser.add_argument( diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 918252cd..2370ca19 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -11,6 +11,9 @@ set -e # 6: Asinh — y = asinh(beta * x) # 7: Log1p — y = sign(x) * log1p(alpha * |x|) # 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) export CUDA_VISIBLE_DEVICES=0 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -20,11 +23,6 @@ sparse_algos=( "block_sparse_attention" ) -topk_mapping_modes=( - 0 # none - 3 # power - 4 # log -) RESULTS_DIR="results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -70,7 +68,50 @@ else done fi +# ============================================================ +# Auto-tune: find best hyperparameters per mode +# Uses topk_profile_histogram kernel on real calibration data +# ============================================================ +REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" +if [ -f "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Auto-tuning hyperparameters (real calibration data)" + echo "============================================================" + AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val 30 \ + --batch-size 4 \ + --seq-len 32768 \ + --num-kv-heads 2 \ + --real-histograms "${REAL_HISTOGRAMS}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" + echo "" + # Extract best per-mode hyperparameters from autotune JSON + eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode') + if m in (3, 6, 7, 9, 10): + if m not in best or r['gini'] < best[m]['gini']: + best[m] = r +for m in (3, 6, 7, 9, 10): + print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') +" "${AUTOTUNE_JSON}")" + echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" + echo "" +else + echo ">>> WARNING: ${REAL_HISTOGRAMS} not found, using default power=0.5 for all modes" + BEST_POWER_3=0.5 + BEST_POWER_6=0.5 + BEST_POWER_7=0.5 + BEST_POWER_9=0.5 + BEST_POWER_10=0.5 +fi # ============================================================ # Mode 1: LUT CDF with calibrated LUT @@ -111,10 +152,10 @@ for algo in "${sparse_algos[@]}"; do done # ============================================================ -# sglang topk: modes that don't need calibration (0, 3, 4) +# sglang topk: non-parametric modes (0, 4, 8, 11) # ============================================================ for algo in "${sparse_algos[@]}"; do - for topk_mapping_mode in "${topk_mapping_modes[@]}"; do + for topk_mapping_mode in 0 4 8 11; do OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_${topk_mapping_mode}_${TIMESTAMP}.log" echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode ${topk_mapping_mode}" echo ">>> Saving results to ${OUTFILE}" @@ -126,50 +167,102 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode ${topk_mapping_mode} \ - --topk-mapping-power 0.5 \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done done # ============================================================ -# Mode 6: asinh — sweep beta values +# Mode 3: power — autotuned best p # ============================================================ for algo in "${sparse_algos[@]}"; do - for beta in 0.5 1.0 2.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" - echo ">>> Running mode 6 (asinh) beta=${beta} for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-power ${beta} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" + echo ">>> Running mode 3 (power) p=${BEST_POWER_3} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power ${BEST_POWER_3} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Mode 7: log1p — sweep alpha values +# Mode 6: asinh — autotuned best beta # ============================================================ for algo in "${sparse_algos[@]}"; do - for alpha in 0.5 1.0 2.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" - echo ">>> Running mode 7 (log1p) alpha=${alpha} for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-power ${alpha} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" + echo ">>> Running mode 6 (asinh) beta=${BEST_POWER_6} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${BEST_POWER_6} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 7: log1p — autotuned best alpha +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" + echo ">>> Running mode 7 (log1p) alpha=${BEST_POWER_7} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${BEST_POWER_7} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 9: erf — autotuned best alpha +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" + echo ">>> Running mode 9 (erf) alpha=${BEST_POWER_9} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 9 \ + --topk-mapping-power ${BEST_POWER_9} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Mode 10: tanh — autotuned best alpha +# ============================================================ +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" + echo ">>> Running mode 10 (tanh) alpha=${BEST_POWER_10} (autotuned) for ${algo}" + echo ">>> Saving results to ${OUTFILE}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 10 \ + --topk-mapping-power ${BEST_POWER_10} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index b701be28..f1c41fe5 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -11,6 +11,9 @@ set -e # 6: Asinh — y = asinh(beta * x) # 7: Log1p — y = sign(x) * log1p(alpha * |x|) # 8: Trunc8 — bf16 upper-8-bit bucketing +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 11: Subtract — x - pivot (RadiK-style scatter) export CUDA_VISIBLE_DEVICES=5 SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" @@ -38,7 +41,7 @@ AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val 30 \ --batch-size 4 \ - --seq-len 4096 \ + --seq-len 32768 \ --num-kv-heads 2 \ --real-histograms "${REAL_HISTOGRAMS}" \ --output-json "${AUTOTUNE_JSON}" \ @@ -47,72 +50,166 @@ echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" echo "" # ============================================================ -# Step 1: Mode 3 (power) — sweep p values +# Extract best per-mode hyperparameters from autotune JSON +# ============================================================ +eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode') + if m in (3, 6, 7, 9, 10): + if m not in best or r['gini'] < best[m]['gini']: + best[m] = r +for m in (3, 6, 7, 9, 10): + print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') +" "${AUTOTUNE_JSON}")" +echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" +echo "" + +# ============================================================ +# Step 1: Mode 3 (power) — autotuned best p +# ============================================================ +echo "============================================================" +echo "Step 1: Mode 3 (power) — p=${BEST_POWER_3} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" + echo ">>> Mode 3 (power) p=${BEST_POWER_3} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 3 \ + --topk-mapping-power ${BEST_POWER_3} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 2: Mode 6 (asinh) — autotuned best beta +# ============================================================ +echo "============================================================" +echo "Step 2: Mode 6 (asinh) — beta=${BEST_POWER_6} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" + echo ">>> Mode 6 (asinh) beta=${BEST_POWER_6} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 6 \ + --topk-mapping-power ${BEST_POWER_6} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 3: Mode 7 (log1p) — autotuned best alpha +# ============================================================ +echo "============================================================" +echo "Step 3: Mode 7 (log1p) — alpha=${BEST_POWER_7} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" + echo ">>> Mode 7 (log1p) alpha=${BEST_POWER_7} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 7 \ + --topk-mapping-power ${BEST_POWER_7} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 4: Mode 8 (trunc8) — fixed parameter +# ============================================================ +echo "============================================================" +echo "Step 4: Mode 8 (trunc8)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_8_${TIMESTAMP}.log" + echo ">>> Mode 8 (trunc8) algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 8 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 5: Mode 9 (erf) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 1: Mode 3 (power) — sweeping p" +echo "Step 5: Mode 9 (erf) — alpha=${BEST_POWER_9} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - for p in 0.1 0.25 0.75 0.9; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${p}_${TIMESTAMP}.log" - echo ">>> Mode 3 (power) p=${p} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-power ${p} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" + echo ">>> Mode 9 (erf) alpha=${BEST_POWER_9} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 9 \ + --topk-mapping-power ${BEST_POWER_9} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Step 2: Mode 6 (asinh) — sweep beta values +# Step 6: Mode 10 (tanh) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 2: Mode 6 (asinh) — sweeping beta" +echo "Step 6: Mode 10 (tanh) — alpha=${BEST_POWER_10} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - for beta in 0.1 0.5 1.0 2.0 4.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${beta}_${TIMESTAMP}.log" - echo ">>> Mode 6 (asinh) beta=${beta} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-power ${beta} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" + echo ">>> Mode 10 (tanh) alpha=${BEST_POWER_10} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 10 \ + --topk-mapping-power ${BEST_POWER_10} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Step 3: Mode 7 (log1p) — sweep alpha values +# Step 7: Mode 11 (subtract) — fixed parameter # ============================================================ echo "============================================================" -echo "Step 3: Mode 7 (log1p) — sweeping alpha" +echo "Step 7: Mode 11 (subtract)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - for alpha in 0.1 0.5 0.75 1.0 2.0 4.0 8.0; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${alpha}_${TIMESTAMP}.log" - echo ">>> Mode 7 (log1p) alpha=${alpha} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-power ${alpha} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" - done + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_11_${TIMESTAMP}.log" + echo ">>> Mode 11 (subtract) algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 11 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" done # ============================================================ @@ -120,9 +217,13 @@ done # ============================================================ echo "" echo "============================================================" -echo "All sweeps complete. Results in ${RESULTS_DIR}/" -echo " Auto-tune: ${AUTOTUNE_JSON}" -echo " Mode 3 (power): p = [0.1, 0.25, 0.75, 0.9]" -echo " Mode 6 (asinh): beta = [0.1, 0.5, 1.0, 2.0, 4.0]" -echo " Mode 7 (log1p): alpha = [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]" +echo "All runs complete. Results in ${RESULTS_DIR}/" +echo " Auto-tune: ${AUTOTUNE_JSON}" +echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" +echo " Mode 8 (trunc8): (fixed)" +echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" +echo " Mode 11 (subtract): (fixed)" echo "============================================================" diff --git a/setup.py b/setup.py index 99c6529b..649f0a08 100644 --- a/setup.py +++ b/setup.py @@ -27,8 +27,6 @@ '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', '-gencode=arch=compute_90,code=sm_90', - '-gencode=arch=compute_100a,code=sm_100a', - '-gencode=arch=compute_120,code=sm_120' ], }, ), diff --git a/third_party/sglang b/third_party/sglang index 5f51c8ef..0ec12893 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 5f51c8ef485fb45990c8166f439da2ee695c03c1 +Subproject commit 0ec12893c4fc0d6ae1d36d4e0512dc21749c4b4b diff --git a/todo.txt b/todo.txt new file mode 100644 index 00000000..53950c3a --- /dev/null +++ b/todo.txt @@ -0,0 +1,308 @@ +1. +prefill 8k/16k/32k block 16/32/64 block topk (8,16) +qwen series 0.6b, 1.7b, 4b, 8b, 16b, 32b +baselines: +flashinfer-fa2/fa3 flashattention v2/v3 +dense + +NSA: block sparse attention +benchmarking: +flash Sparse Attention +https://github.com/Relaxed-System-Lab/Flash-Sparse-Attention + +https://github.com/mit-han-lab/flash-moba + + +Video generation: VAE fp8 convolution +wan 2.1 vae +1.3B input 480P + +3.17: +(For SOSP 26) +warpper: prefill: +1. ragged paged, warpper +disable_radix_cache +new backend: goal sparse prefill with topk on a new warpper: abandon the previous paged warpper, +apply 1 to the whole prefill sequence + +2. topk +idea: we want to improve the current topk kernel /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu for our project, +we want to map the value in each layer for the topk selection to a new distribution that bucket sort is more efficient on. like to make the values to be more uniform +in each bucket. +the number of the heads has a certain distribution, try to adapt to it. +The mapping function should have a low overhead, or it would damage the end2end efficiency. +Key: first profile the whole process, record the distribution of the value in each layer. You need a profile script for this. save the results. +Then design a novel mapping function(can be easily customize by me), to map the value to a new distribution. Don't change the correctness of the sorting, but more efficient for the bucket sort in the /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu +here is some options: +Option A: Adaptive Bit Selection (2-Pass Min/Max on uint32 Key) + +Core idea: Instead of always extracting the top 8 bits of the fp16 key, extract the 8 most-significant varying bits of the full 32-bit key by finding the actual key range within each segment. + +Algorithm: +Pass 1: Parallel warp-reduction to find min_key and max_key over convert_to_uint32(x) for the segment (~N/1024 iterations per thread, same as current Stage 1) +Compute shift = max(0, 31 - clz(max_key - min_key) - 8) (find the bit position of the 8 most-significant differing bits) + +Pass 2: bin = ((convert_to_uint32(x) - min_key) >> shift) & 0xFF — this uses ALL 32 bits of float precision for binning +Overhead: ~2x data reads for Stage 1 (one extra scan for min/max). Min/max reduction can be done with __shfl_xor_sync followed by a single atomicMin/atomicMax in shared memory — very efficient. + +Pros: +Uses full 32-bit float precision instead of just 8 fp16 bits (up to 2^24 = 16M effective resolution levels instead of 256) +Perfectly adaptive to any data range — no calibration needed + +Pure integer arithmetic in the hot loop (shift + subtract + mask) +Guaranteed monotonic (linear mapping of uint32 keys) +Cons: +2x memory bandwidth for Stage 1 (the min/max pass re-reads all data) +Doesn't guarantee perfectly uniform bin counts (a skewed distribution within [min, max] still skews bins) +Expected quality: Excellent for the real distribution (where the range is narrow but contains many distinct float values that the 8-bit fp16 extraction cannot distinguish). +Implementation sketch: + +// Pass 1: find min/max key via parallel reduction +__shared__ uint32_t s_min_key, s_max_key; +if (tx == 0) { s_min_key = 0xFFFFFFFF; s_max_key = 0; } +__syncthreads(); +for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + uint32_t k = convert_to_uint32(input[idx + row_start]); + atomicMin(&s_min_key, k); + atomicMax(&s_max_key, k); +} +__syncthreads(); +uint32_t range = s_max_key - s_min_key; +int shift = max(0, 31 - __clz(range | 1) - 7); // 8 MSBs of range + +// Pass 2: histogram with adaptive bins +for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + uint32_t k = convert_to_uint32(input[idx + row_start]); + uint8_t bin = ((k - s_min_key) >> shift) & 0xFF; + atomicAdd(&s_histogram[bin], 1); +} + + + +Option B: Strided Sampling + Approximate CDF Equalization +Core idea: Read a sparse sample (every S-th element, e.g. S=32) to build an approximate histogram, compute a CDF-equalized LUT from it in shared memory, then apply the LUT during the full scan. +Algorithm: +Sampling pass: Read every 32nd element, build approximate 256-bin histogram in shared memory +In-block LUT construction: 256 threads compute prefix sum -> CDF -> equalized LUT (8 iterations of parallel prefix sum, same as existing run_cumsum) +Full scan: Apply s_lut[convert_to_uint8(x)] for each element +Overhead: ~(1/32 + 1)x = ~1.03x memory reads. LUT construction is ~8 syncthreads (trivial). The hot-loop cost is identical to existing Mode 1 (one shared memory lookup). +Pros: +Near-zero extra bandwidth (only 3% overhead from sampling) +Fully adaptive — no offline calibration needed +Self-tuning: the LUT is computed from the current segment's own data +Hot-loop cost identical to existing LUT mode (1 shared memory read) +Cons: +Approximate: sampling introduces noise in the estimated CDF (especially for small segments) +More complex control flow (3 phases in Stage 1) +For very short segments (<1024 elements), the sample may be too small +Expected quality: Very good. With 1/32 sampling on a segment of 4096+ elements, the CDF estimate has ~128+ samples — sufficient for a good 256-entry LUT. +Implementation sketch: + +// Phase 0: sampled histogram +__shared__ int s_sample_hist[256]; +if (tx < 256) s_sample_hist[tx] = 0; +__syncthreads(); +for (int idx = tx * 32; idx < length; idx += BLOCK_SIZE * 32) { + uint8_t bin = convert_to_uint8(input[idx + row_start]); + atomicAdd(&s_sample_hist[bin], 1); +} +__syncthreads(); + +// Phase 0.5: compute equalized LUT from sampled histogram +__shared__ uint8_t s_eq_lut[256]; +// ... prefix sum on s_sample_hist -> CDF -> s_eq_lut[i] = floor(CDF(i) * 255) + +// Phase 1: full scan with equalized LUT +for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + uint8_t raw_bin = convert_to_uint8(input[idx + row_start]); + uint8_t eq_bin = s_eq_lut[raw_bin]; + atomicAdd(&s_histogram[eq_bin], 1); +} + + +Option C: Online 2-Pass Histogram Equalization +Core idea: The "gold standard" — build the exact histogram first, compute CDF equalization in shared memory, then re-scan with the equalized mapping. +Algorithm: + +Pass 1: Build exact 256-bin histogram using convert_to_uint8(x) (same as current code) +In-block: 256 threads compute prefix sum -> CDF -> equalized LUT (lut[i] = floor(CDF(i) * 255)) + +Pass 2: Re-scan ALL elements, apply s_lut[convert_to_uint8(x)], build equalized histogram -> find threshold +Overhead: Exactly 2x data reads for Stage 1. LUT construction is negligible (~8 __syncthreads steps). +Pros: +Optimal histogram equalization — produces a provably uniform distribution (each equalized bin has almost exactly N/256 elements) +No calibration, no approximation — perfect adaptation to any input distribution +Monotonic (CDF is monotonically non-decreasing) + +Cons: +2x memory bandwidth for Stage 1 (dominant cost) +May not be worth it if Stage 2 is already fast (the overhead of the extra pass might exceed the savings from a smaller threshold bin) +Expected quality: Optimal. The threshold bin after equalization will have ~N/256 elements regardless of input distribution. +When this is worthwhile: When the original hot bin contains a very large fraction of elements (e.g., >20% of N), the savings from nearly eliminating Stage 2 easily outweigh the extra read pass. Given real data's Gini=0.809, this is likely the case. + +Option D: Temporal CDF Caching (Self-Calibrating LUT) +Core idea: The score distribution per head is relatively stable across adjacent decoding steps. After each kernel invocation, write the histogram to global memory. At the start of the NEXT invocation, load it and compute the CDF-equalized LUT. +Algorithm: +At kernel launch, load the previous iteration's histogram from a persistent global buffer prev_hist[head_id][256] +Compute equalized LUT in shared memory (prefix sum -> CDF -> LUT) +Stage 1 uses this LUT (same as Mode 1) +After Stage 1's histogram is built, write it to prev_hist[head_id][256] for the next iteration +Overhead: 1 shared memory lookup per element (identical to existing Mode 1). Plus ~256 int32 reads + ~256 int32 writes per kernel launch (negligible). +Pros: +Near-zero per-element overhead (shared memory LUT lookup) +Self-calibrating — no offline calibration step needed +Adapts to distribution changes over time (with 1-step lag) +Builds directly on existing Mode 1 infrastructure +Very low implementation complexity + +Cons: +1-step lag: the LUT is based on the previous step's data (cold start on first iteration) +Requires persistent global memory buffer (~256 * 4 bytes per head) +May produce suboptimal LUT if the distribution changes rapidly between steps +Expected quality: Very good after the first few iterations. Attention score distributions evolve slowly during generation, so the 1-step lag has minimal impact. +Implementation changes: +Add prev_histogram pointer to TopKMappingParams +Add MAPPING_TEMPORAL_CDF = 7 mode +In kernel: load prev histogram -> compute LUT -> use LUT -> write current histogram +Python side: allocate persistent [num_heads, 256] int32 buffer, pass to kernel + +Option E: Adaptive Exponent-Mantissa Bit Packing +Core idea: The current 8-bit extraction uses 5 exponent + 2 mantissa bits from fp16. When the actual exponent range is narrow (e.g., only 2-3 distinct exponents), most of those 5 exponent bits are wasted. Dynamically reallocate bits: use fewer for the exponent, more for the mantissa. +Algorithm: +Calibration or per-block scan: Determine exponent range [E_min, E_max] of the scores +Choose bit layout based on range width: +Range 1-2 exponents: 1 exp bit + 6 mantissa bits + 1 sign = 64 bins/exponent (vs 4 currently) → 16x improvement +Range 3-4: 2 exp + 5 mantissa + 1 sign +Range 5-8: 3 exp + 4 mantissa + 1 sign +Range 9-16: 4 exp + 3 mantissa + 1 sign +Wider: original 5 exp + 2 mantissa + 1 sign + +Apply: bin = ((exp - E_min) << mantissa_bits) | (mantissa >> (10 - mantissa_bits)) for positive values (with sign-magnitude ordering) +Overhead: Very low (~5-8 integer instructions per element: extract exponent, subtract base, shift, OR with mantissa). No extra memory reads. No LUT. +Pros + +Extremely low overhead — pure register-level bit manipulation +No extra memory reads or shared memory usage +Monotonic (order-preserving within each exponent, and across exponents) +Up to 16x better bin resolution for narrow distributions +Cons: +Requires knowing E_min/E_max (either calibrated offline, or from a quick per-block reduction) +Not as "perfect" as CDF equalization — distribution within each exponent may still be non-uniform +More complex bit manipulation logic +Expected quality: Very good for the observed real distribution (narrow exponent range → dramatic improvement in bin resolution). Not optimal for arbitrary distributions. +Implementation sketch: + +// Assuming E_min, E_max precomputed and passed in TopKMappingParams +__device__ __forceinline__ uint8_t map_adaptive_bits(float x, int e_min, int e_range) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? ~bits : (bits | 0x8000); + int exp_val = (key >> 10) & 0x1F; // 5-bit exponent + int mantissa = key & 0x3FF; // 10-bit mantissa + + // Determine bit allocation based on e_range + int exp_bits, mant_bits; + if (e_range <= 2) { exp_bits = 1; mant_bits = 6; } + else if (e_range <= 4) { exp_bits = 2; mant_bits = 5; } + else if (e_range <= 8) { exp_bits = 3; mant_bits = 4; } + else { return (uint8_t)(key >> 8); } // fallback + + int sign_bit = (key >> 15) & 1; + int exp_part = min((exp_val - e_min), (1 << exp_bits) - 1); + int mant_part = mantissa >> (10 - mant_bits); + return (uint8_t)((sign_bit << 7) | (exp_part << mant_bits) | mant_part); +} +Use a new file to store these options, and make sure I can switch between these options. +3. Adapt Sparse attentions to the vortex this include: +(1) Naive Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/nsa_ref +(2) Flash Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/fsa +(3) FlashMoBA /scr/dataset/yuke/xinrui/Sparse-benchmark/flash-moba +Need to implement the whole sparse attention kernel, use their attention backend. Replace forward extend + + +# How to use the custom mapping function: +--- + Writing a Custom Mapping Function + + All mapping logic lives in csrc/topk_mapping.cuh. To add a new mode: + + Step 1: Add to the enum in topk_mapping.cuh: + + enum TopKMappingMode { + MAPPING_NONE = 0, + MAPPING_LUT_CDF = 1, + MAPPING_QUANTILE = 2, + MAPPING_POWER = 3, + MAPPING_LOG = 4, + MAPPING_CUSTOM = 5, // <-- your new mode + }; + + Step 2: Write your __device__ mapping function. It must take a float score and return a uint8_t bin index (0–255). The mapping must be monotonic (order-preserving) to ensure + correctness: + + __device__ __forceinline__ + uint8_t map_custom(float x) { + // Example: sqrt transform + float mapped = copysignf(sqrtf(fabsf(x)), x); + return convert_to_uint8(mapped); + } + + Step 3: Add a case to the dispatcher mapped_convert_to_uint8(): + + __device__ __forceinline__ + uint8_t mapped_convert_to_uint8(float x, const TopKMappingParams& params) { + switch (params.mode) { + // ... existing cases ... + case MAPPING_CUSTOM: + return map_custom(x); + default: + return convert_to_uint8(x); + } + } + + Step 4: Update verify_algo.py to accept the new mode value. In parse_args(), change the choices: + + parser.add_argument("--topk-mapping-mode", type=int, default=0, + choices=[0, 1, 2, 3, 4, 5], # add 5 + ...) + + Step 5: Rebuild and test: + + pip install -e . + python examples/verify_algo.py --topk-type sglang --topk-mapping-mode 5 ... + + Key constraint: Your mapping function only affects Stage 1 (coarse 256-bin histogram). Stage 2 refinement always uses raw float bits via convert_to_uint32(), so the final + top-K selection is always correct regardless of your mapping. The goal is to make the Stage 1 histogram more uniform so fewer elements land in the threshold bin, reducing + Stage 2 work. + + If your custom mapping needs extra parameters (like a tensor or scalar), add them to the TopKMappingParams struct, pass them through topk_output_sglang() host function, update + register.h/register.cc bindings, and read them from ctx in output_func.py. + + 1. csrc/topk_sglang.cu — New CUDA kernel + + - TopKHitRate_Kernel: Stage-1-only kernel with mapping support. Builds 256-bin histogram, writes raw histogram to global memory, runs cumsum to find threshold bin, then + computes stage1_resolved = nblk - items_in_threshold_bin. No Stage 2 needed. + - topk_hit_rate(): Host entry point mirroring topk_output_sglang() for mapping param construction and dtype dispatch. + + 2. csrc/register.h + csrc/register.cc — PyBind11 bindings + + - Added topk_hit_rate declaration and m.def(...) binding with default args for mapping params. + + 3. benchmarks/bench_topk.py — Benchmark integration + + - Added compute_hit_rate_stats() helper for per-segment resolution rate + histogram stats. + - Added --hit-rate CLI flag. When enabled, iterates over available mapping modes (0, 3, 4 always; 1 if LUT provided; 2 if quantiles provided) and prints a comparison table. + - Results stored in config_results["hit_rate"] for JSON output. + + 4. benchmarks/analyze_topk_distribution.py — New visualization script + + - 5 plot functions: bin distribution, heatmap, before/after mapping, summary table, mode comparison. + - Loads from --profile-npz and/or --bench-json. + + 5. End-to-end integration (your request) + + - vortex_torch/indexer/context.py: Added topk_hit_rate_enabled slot, populated from sa.vortex_topk_hit_rate (default False). + - vortex_torch/indexer/output_func.py: After the main topk kernel call, if ctx.topk_hit_rate_enabled is True and topk_type is "sglang", it calls topk_hit_rate() and stores + results in self.last_hit_rate_stats / self.last_hit_rate_histograms. Zero overhead when disabled — just a getattr check that short-circuits. + + To enable during inference, set vortex_topk_hit_rate=True in your SGLang server args. \ No newline at end of file diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 78e2923a..8142fbc6 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -25,6 +25,7 @@ class Context(ContextBase): # misc "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", "topk_mapping_mode", "topk_mapping_power", "topk_mapping_lut", "topk_mapping_quantiles", + "topk_mapping_noscale", "topk_histogram_enabled", # auxilary memory in graph @@ -76,6 +77,7 @@ class Context(ContextBase): topk_mapping_power: float #: Power exponent for mapping mode 3. topk_mapping_lut: object #: Optional uint8[256] LUT tensor for mapping mode 1. topk_mapping_quantiles: object #: Optional float32[256] quantiles tensor for mapping mode 2. + topk_mapping_noscale: bool #: Skip auto-range linear scaling, use fp16 bucketing on f(x) (default False). topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- @@ -156,6 +158,7 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) + self.topk_mapping_noscale = getattr(sa, "vortex_topk_mapping_noscale", False) self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) device = getattr(model_runner, "device", "cpu") diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index e4208dc8..53e97173 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -249,6 +249,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_power = getattr(ctx, 'topk_mapping_power', 0.5) mapping_lut = getattr(ctx, 'topk_mapping_lut', None) mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) + mapping_noscale = getattr(ctx, 'topk_mapping_noscale', False) # UNSET sentinel is not a valid torch.Tensor — coerce to None if mapping_lut is UNSET: mapping_lut = None @@ -269,6 +270,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, ) else: # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) @@ -306,6 +308,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, ) # Accumulate histograms for offline calibration _calibration_histograms.append(self.last_histograms.cpu().clone()) From 15f1d03578cfa845e322a0ae6b85498e559b19ca Mon Sep 17 00:00:00 2001 From: UED Date: Thu, 2 Apr 2026 04:12:19 +0000 Subject: [PATCH 15/24] enhance TopK mapping with adaptive tail-window mode; modify example scripts to reflect changes in histogram calibration and TopK mapping parameters. --- CLAUDE.md | 172 ------- csrc/register.cc | 3 +- csrc/register.h | 3 +- csrc/topk_mapping.cuh | 33 +- csrc/topk_sglang.cu | 212 ++++++++- csrc/topk_slgang_ori.cu | 546 ---------------------- examples/run_distribution_analysis.sh | 2 +- examples/run_distribution_analysis_new.sh | 2 +- examples/verify_algo.py | 6 +- examples/verify_algo_topk_mapping_new.sh | 25 +- setup.py | 3 + todo.txt | 308 ------------ vortex_torch/indexer/output_func.py | 1 + 13 files changed, 271 insertions(+), 1045 deletions(-) delete mode 100644 CLAUDE.md delete mode 100644 csrc/topk_slgang_ori.cu delete mode 100644 todo.txt diff --git a/CLAUDE.md b/CLAUDE.md deleted file mode 100644 index 585a246f..00000000 --- a/CLAUDE.md +++ /dev/null @@ -1,172 +0,0 @@ -# CLAUDE.md - -This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. - -## Project Overview - -Vortex is a lightweight, modular framework for building custom sparse attention algorithms for LLM inference. It provides a PyTorch-like frontend that abstracts away batching, caching, and paged attention, running on optimized backends (FlashInfer, CUDA Graph) via SGLang integration. - -## Build & Install - -```bash -# Clone with submodules -git clone -b v1 --recursive - -# Install SGLang dependency (custom fork in third_party/, supports v0.4.9) -cd third_party/sglang && bash install.sh && cd ../../ - -# Install Vortex (editable mode, compiles CUDA extensions for SM_86/SM_89/SM_90) -pip install -e . -``` - -Requires Python >=3.10, torch>=2.7, lighteval[math]==0.12.2. CUDA extensions (`vortex_torch_C`) are built from `csrc/` (register.cc, utils_sglang.cu, topk.cu, topk_sglang.cu). - -## Testing & Verification - -There is no formal test suite (no pytest). Verification is done by running algorithms against SGLang reference output and comparing accuracy on math benchmarks. - -```bash -# Single algorithm verification (from examples/ directory) -python examples/verify_algo.py --trials 2 --topk-val 30 --vortex-module-name block_sparse_attention - -# Full options -python examples/verify_algo.py \ - --trials 8 --topk-val 30 \ - --vortex-module-name block_sparse_attention \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type naive \ - --mem 0.7 - -# Batch test (outputs timestamped logs to examples/results/) -bash examples/verify_algo.sh - -# AIM24 benchmark verification -python examples/verify_aim24.py -``` - -Available `--topk-type` values: `naive` (CUB-based), `sglang` (SGLang-integrated kernel). - -## AI-Powered Algorithm Generation - -```bash -# Generate new sparse attention algorithms via OpenHands (requires LLM_API_KEY env var) -python openhands_gen.py -``` - -Note: Some auto-generated operators may not be fully optimized. Tune `mem_fraction_static` if OOM occurs. - -## Building Documentation - -```bash -make -C docs html -``` - -Uses Sphinx with myst_parser and furo theme. Deployed via GitHub Actions on push to v1 branch. - -## Architecture - -### Core Abstraction: vFlow (`vortex_torch/flow/flow.py`) - -All sparse attention algorithms inherit from `vFlow` and implement three methods: - -- **`forward_indexer(q, o, cache, ctx)`** — Compute sparse page indices from queries. Operates on page-packed tensor view `[S, r, c]`. -- **`forward_cache(cache, loc, ctx)`** — Update/summarize custom cache tensors when a page completes. Operates on batch-major view `[B, r, c]`. -- **`create_cache(page_size, head_dim)`** — Declare custom cache tensor shapes as a dict of `{name: (rows, cols)}`. - -Algorithms are registered via `@register("name")` decorator and instantiated with `build_vflow()`. - -### Operator System (`vortex_torch/indexer/`, `vortex_torch/cache/`) - -Operators (`vOp` subclasses) run in two modes: -- **Profile mode**: Pre-compute output shapes and allocate buffers -- **Execute mode**: Perform actual GPU computation - -Operators are split into two parallel hierarchies: -- **Indexer ops** (`vortex_torch/indexer/`): GeMM, GeMV, topK, reduce (Mean/Max/Min/Sum/L2Norm), softmax, elementwise, transpose, save/load -- **Cache ops** (`vortex_torch/cache/`): GeMM, reduce, elementwise, fill, KV buffer setup - -Both use Triton kernels (in respective `triton_kernels/` subdirectories) for GPU execution. - -### Tensor Format (`vortex_torch/abs/tensor.py`) - -`vTensor` wraps `torch.Tensor` with format metadata (BATCHED, RAGGED, PAGED) to enforce layout consistency across operations. - -### Context System (`vortex_torch/abs/context_base.py`) - -`ContextBase` carries per-step runtime state. Specialized as: -- `Indexer.Context`: Page layout, head config, hardware info -- `Cache.Context`: Page size, total pages, model info - -### Concrete Algorithms (`vortex_torch/flow/algorithms.py`) - -- **BlockSparseAttention**: Centroid-based routing (query avg → GeMV with centroids → topK) -- **GQABlockSparseAttention**: Grouped-query variant with softmax + group aggregation -- **GQAQuestSparseAttention**: Query-envelope matching using per-page max/min bounds - -### Algorithm Registry (`vortex_torch/flow/registry.py`) - -Algorithms are registered via `@register("name")` and looked up with `get(name)`, `has(name)`, `list_keys()`. Factory: `build_vflow(name)` in `loader.py`. - -### SGLang Integration - -Custom SGLang fork lives in `third_party/sglang` (git submodule, "graph" branch). CUDA extensions in `csrc/` provide PyBind11 bindings for `sglang_plan_decode`, `sglang_plan_prefill`, transpose operations (NH↔HN), and top-K output routing. - -## Key Conventions - -- **Tensor shapes**: Query `[B, H_q, D]`, sparse output `[S_sparse, 1, 1]`, cache indexer-view `[S, r, c]`, cache batch-view `[B, r, c]` -- **GeMM semantics**: `GeMM(x, y)` computes `y @ x^T` (note transposition) -- **Standard cache keys**: `"k"` and `"v"` have inner shape `(page_size, head_dim)`; custom caches declared in `create_cache()` -- **Branch**: Main development is on `v1` - -## Workflow Orchestration - -### 1. Plan Node Default -- Enter plan mode for ANY non-trivial task (3+ steps or architectural decisions) -- If something goes sideways, STOP and re-plan immediately - don't keep pushing -- Use plan mode for verification steps, not just building -- Write detailed specs upfront to reduce ambiguity - -### 2. Subagent Strategy -- Use subagents liberally to keep main context window clean -- Offload research, exploration, and parallel analysis to subagents -- For complex problems, throw more compute at it via subagents -- One tack per subagent for focused execution - -### 3. Self-Improvement Loop -- After ANY correction from the user: update `tasks/lessons.md` with the pattern -- Write rules for yourself that prevent the same mistake -- Ruthlessly iterate on these lessons until mistake rate drops -- Review lessons at session start for relevant project - -### 4. Verification Before Done -- Never mark a task complete without proving it works -- Diff behavior between main and your changes when relevant -- Ask yourself: "Would a staff engineer approve this?" -- Run tests, check logs, demonstrate correctness - -### 5. Demand Elegance (Balanced) -- For non-trivial changes: pause and ask "is there a more elegant way?" -- If a fix feels hacky: "Knowing everything I know now, implement the elegant solution" -- Skip this for simple, obvious fixes - don't over-engineer -- Challenge your own work before presenting it - -### 6. Autonomous Bug Fixing -- When given a bug report: just fix it. Don't ask for hand-holding -- Point at logs, errors, failing tests - then resolve them -- Zero context switching required from the user -- Go fix failing CI tests without being told how - -## Task Management - -1. **Plan First**: Write plan to `tasks/todo.md` with checkable items -2. **Verify Plan**: Check in before starting implementation -3. **Track Progress**: Mark items complete as you go -4. **Explain Changes**: High-level summary at each step -5. **Document Results**: Add review section to `tasks/todo.md` -6. **Capture Lessons**: Update `tasks/lessons.md` after corrections - -## Core Principles - -- **Simplicity First**: Make every change as simple as possible. Impact minimal code. -- **No Laziness**: Find root causes. No temporary fixes. Senior developer standards. -- **Minimal Impact**: Changes should only touch what's necessary. Avoid introducing bugs. \ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index 0a3c11ca..af49aecd 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -27,7 +27,8 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_noscale") = false, + py::arg("topk_val") = 0); m.def("topk_profile_stage1", &topk_profile_stage1, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), diff --git a/csrc/register.h b/csrc/register.h index 1a8b8207..dae5e825 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -114,7 +114,8 @@ const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +const bool mapping_noscale = false, +const int64_t topk_val = 0 ); void topk_profile_stage1( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 97bc141d..8dbb8084 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -4,16 +4,24 @@ #include // ============================================================ -// TopK bucket-sort distribution mapping strategies +// TopK bucket-sort Stage-1 remapping strategies // // These transforms remap float scores before Stage 1's 8-bit -// histogram binning, aiming for a more uniform distribution -// across the 256 coarse bins. Stage 2 refinement still uses -// convert_to_uint32() on raw floats, so correctness is preserved. +// histogram binning. The primary goal is to maximize coarse-bin +// resolution in the score region that determines the top-k +// cutoff, thereby: +// - shrinking the Stage-1 threshold bin (fewer collisions) +// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT +// - reducing the number of Stage-2 refine rounds // -// Modes 3/4/6/7 use a data-adaptive linear mapping to [0,255] -// instead of fp16 bit-pattern bucketing, guaranteeing full -// bucket utilization regardless of value range. +// Stage 2 refinement still uses convert_to_uint32() on raw +// floats, so final ordering correctness is always preserved. +// +// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly +// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) +// directly focuses all 256 bins on the competitive upper tail +// estimated from the top-k ratio, collapsing irrelevant +// low-score mass into bin 0. // ============================================================ enum TopKMappingMode { @@ -29,15 +37,19 @@ enum TopKMappingMode { MAPPING_ERF = 9, // erf(alpha * x) MAPPING_TANH = 10, // tanh(alpha * x) MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing + MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile }; struct TopKMappingParams { int mode; // TopKMappingMode float power_exp; // For MAPPING_POWER (default 0.5) + // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion + // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) + int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW }; // NOTE: convert_to_uint8() must be defined before including this header. @@ -158,6 +170,8 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( return convert_to_uint8_bf16(x); case MAPPING_SUBTRACT: return convert_to_uint8(x - range_min); // range_min repurposed as pivot + case MAPPING_ADAPTIVE_TAIL_WINDOW: + return linear_map_to_uint8(x, range_min, inv_range); default: // MAPPING_NONE return convert_to_uint8(x); } @@ -174,3 +188,8 @@ __device__ __forceinline__ bool needs_auto_range(int mode) { __device__ __forceinline__ bool needs_pivot(int mode) { return (mode == MAPPING_SUBTRACT); } + +// Helper: check if mode is the adaptive tail-window pre-pass +__device__ __forceinline__ bool needs_tail_window(int mode) { + return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 9213016a..867efbed 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -571,6 +571,113 @@ __device__ void fast_topk_vortex( } } __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. + constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -991,6 +1098,96 @@ void TopKHistogram_Kernel( } } __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass (histogram kernel variant) + constexpr int MAX_SAMPLES_H = 1024; + __shared__ float s_samples_h[MAX_SAMPLES_H]; + __shared__ int s_sample_count_h; + + if (tx == 0) s_sample_count_h = 0; + __syncthreads(); + + const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; + const int sample_stride_h = max(desired_stride, 1); + + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { + float val = vortex_to_float(score_blk[idx]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count_h, 1); + if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; + } + + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_h[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_h[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_h[0]; + + int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); + + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = mapping.target_k; + float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; + frac = fmaxf(frac, 0.0f); + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; + } + + if (tau_low >= local_max) { + tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -1164,6 +1361,7 @@ void topk_output_sglang( mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; mapping.sample_stride = 1; + mapping.target_k = static_cast(topk_val); if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1233,7 +1431,8 @@ void topk_profile_histogram( const double mapping_power, std::optional mapping_lut, std::optional mapping_quantiles, - const bool mapping_noscale) + const bool mapping_noscale, + const int64_t topk_val) { CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); @@ -1252,6 +1451,7 @@ void topk_profile_histogram( mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; mapping.sample_stride = 1; + mapping.target_k = static_cast(topk_val); if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1305,7 +1505,8 @@ static TopKMappingParams build_mapping_params( std::optional& mapping_lut, std::optional& mapping_quantiles, bool mapping_noscale = false, - int sample_stride = 1) + int sample_stride = 1, + int target_k = 0) { TopKMappingParams mapping{}; mapping.mode = static_cast(mapping_mode); @@ -1314,6 +1515,7 @@ static TopKMappingParams build_mapping_params( mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; mapping.sample_stride = sample_stride; + mapping.target_k = target_k; if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); @@ -1356,7 +1558,8 @@ void topk_profile_stage1( "topk_profile_stage1: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); @@ -1428,7 +1631,8 @@ void topk_profile_counters( TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, "counters must be int32"); - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, mapping_noscale); + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); diff --git a/csrc/topk_slgang_ori.cu b/csrc/topk_slgang_ori.cu deleted file mode 100644 index 04a2b73b..00000000 --- a/csrc/topk_slgang_ori.cu +++ /dev/null @@ -1,546 +0,0 @@ -/** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. fix the potential illegal memory access - */ -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -namespace { - -constexpr int TopK = 2048; -constexpr int kThreadsPerBlock = 1024; - -#ifdef USE_ROCM -// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a -// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. -#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES -constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); -#else -constexpr size_t kSmem = 48 * 1024; // bytes -#endif -#else -// Reduced from 128KB to 32KB to improve occupancy. -// Each radix pass needs at most ~TopK candidates in the threshold bin, -// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. -constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) -#endif - -struct FastTopKParams { - const float* __restrict__ input; // [B, input_stride] - const int32_t* __restrict__ row_starts; // [B] - int32_t* __restrict__ indices; // [B, TopK] - int32_t* __restrict__ lengths; // [B] - int64_t input_stride; -}; - -// when length <= TopK, we can directly write the indices -__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { - const auto tid = threadIdx.x; - for (int i = tid; i < TopK; i += kThreadsPerBlock) { - indice[i] = (i < length) ? i : -1; - } -} - -// keep the first `length` entries, set others to -1 -__device__ void naive_topk_transform( - const float* __restrict__ score, - int32_t length, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - dst_page_table[i] = (i < length) ? src_page_table[i] : -1; - } -} - -// keep the first `length` entries, set others to -1 -__device__ void naive_topk_transform_ragged( - const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { - const auto tid = threadIdx.x; - for (auto i = tid; i < TopK; i += kThreadsPerBlock) { - topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; - } -} - -__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return static_cast(key >> 8); -} - -__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); -} - -__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { - // An optimized topk kernel copied from tilelang kernel - // We assume length > TopK here, or it will crash - int topk = TopK; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - // allocate for two rounds - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; - - const int tx = threadIdx.x; - - // stage 1: 8bit coarse histogram - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(input[idx + row_start]); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = input[idx + row_start]; - const auto bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - /// NOTE: (dark) fuse the histogram computation here - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // stage 2: refine with 8bit radix passes -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - // clip here to prevent overflow - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = input[idx + row_start]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[TopK - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - /// NOTE: (dark) fuse the histogram computation here - s_input_idx[r_idx ^ 1][pos] = idx; - const auto bin = convert_to_uint32(raw_input); - const auto sub_bin = (bin >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); - } - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // topk - void topk_kernel(const FastTopKParams params) { - const auto& [input, row_starts, indices, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto indice = indices + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_cuda(score, indice, length); - } else { - return fast_topk_cuda_tl(score, indice, row_start, length); - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // decode - void topk_transform_decode_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride) { - const auto& [input, _1, _2, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = 0; - const auto length = lengths[bid]; - const auto src_page_entry = src_page_table + bid * src_stride; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // prefill - void topk_transform_prefill_kernel( - const FastTopKParams params, - int32_t* __restrict__ dst_page_table, - const int32_t* __restrict__ src_page_table, - const int64_t src_stride, - const int32_t* __restrict__ cu_seqlens_q, - const int64_t prefill_bs) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto length = lengths[bid]; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto dst_page_entry = dst_page_table + bid * TopK; - const auto score = input + bid * input_stride; - - /// NOTE: prefill bs is usually small, we can just use a simple loop here - /// We ensure that last cu_seqlens is equal to number of blocks launched - __shared__ const int32_t* s_src_page_entry; - if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { - if (tid < prefill_bs) { - if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { - s_src_page_entry = src_page_table + tid * src_stride; - } - } - } else { - for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { - if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { - s_src_page_entry = src_page_table + i * src_stride; - } - } - } - __syncthreads(); - const auto src_page_entry = s_src_page_entry; - - if (length <= TopK) { - return naive_topk_transform(score, length, dst_page_entry, src_page_entry); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_page_entry[idx_0] = src_page_entry[pos_0]; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_page_entry[idx_1] = src_page_entry[pos_1]; - } -} - -__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv - void topk_transform_prefill_ragged_kernel( - const FastTopKParams params, - int32_t* __restrict__ topk_indices_ragged, - const int32_t* __restrict__ topk_indices_offset) { - const auto& [input, row_starts, _, lengths, input_stride] = params; - const auto bid = static_cast(blockIdx.x); - const auto tid = threadIdx.x; - const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; - const auto length = lengths[bid]; - const auto dst_indices_entry = topk_indices_ragged + bid * TopK; - const auto score = input + bid * input_stride; - const auto offset = topk_indices_offset[bid]; - - if (length <= TopK) { - return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); - } else { - __shared__ int s_indices[TopK]; - fast_topk_cuda_tl(score, s_indices, row_start, length); - // copy src[s_indices] to dst, we manually unroll here - static_assert(TopK % kThreadsPerBlock == 0); - static_assert(TopK / kThreadsPerBlock == 2); - const auto idx_0 = tid; - const auto pos_0 = s_indices[idx_0]; - dst_indices_entry[idx_0] = pos_0 + offset; - const auto idx_1 = tid + kThreadsPerBlock; - const auto pos_1 = s_indices[idx_1]; - dst_indices_entry[idx_1] = pos_1 + offset; - } -} - -auto get_params( - const at::Tensor& score, - const at::Tensor& lengths, - std::optional row_starts_opt = std::nullopt, - std::optional indices_opt = std::nullopt) -> FastTopKParams { - const auto B = score.size(0); - TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); - if (row_starts_opt.has_value()) { - const auto& row_starts = row_starts_opt.value(); - TORCH_CHECK(row_starts.dim() == 1); - TORCH_CHECK(row_starts.size(0) == B); - } - TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); - TORCH_CHECK(lengths.size(0) == B); - int32_t* indices_data_ptr = nullptr; - if (indices_opt.has_value()) { - const auto& indices = indices_opt.value(); - TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); - TORCH_CHECK(indices.size(0) == B); - TORCH_CHECK(indices.size(1) == TopK); - indices_data_ptr = indices.data_ptr(); - } - - return FastTopKParams{ - .input = score.data_ptr(), - .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, - .indices = indices_data_ptr, - .lengths = lengths.data_ptr(), - .input_stride = score.stride(0), - }; -} - -template -void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { -#ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); -} - -} // namespace - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -void fast_topk_interface( - const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(indices); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - CHECK_CUDA(lengths); - const auto params = get_params(score, lengths, row_starts_opt, indices); - const auto B = score.size(0); - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - setup_kernel_smem_once(); - topk_kernel<<>>(params); - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); -} - -void fast_topk_transform_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& dst_page_table, - const at::Tensor& src_page_table, - const at::Tensor& cu_seqlens_q, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(dst_page_table); - CHECK_CUDA(src_page_table); - CHECK_CUDA(cu_seqlens_q); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); - TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); - TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); - const auto prefill_bs = cu_seqlens_q.size(0) - 1; - TORCH_CHECK(dst_page_table.size(0) == B); - TORCH_CHECK(dst_page_table.size(1) == TopK); - TORCH_CHECK(src_page_table.size(0) == prefill_bs); - TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - const auto src_stride = src_page_table.stride(0); - - // dispatch to decode or prefill - // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel - // decode: row_starts_opt is null, invokes the decode kernel - // target verify: row_starts_opt is null, invokes the prefill kernel - const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; - if (is_decode) { - setup_kernel_smem_once(); - topk_transform_decode_kernel<<>>( - params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); - } else { - setup_kernel_smem_once(); - topk_transform_prefill_kernel<<>>( - params, - dst_page_table.data_ptr(), - src_page_table.data_ptr(), - src_stride, - cu_seqlens_q.data_ptr(), - prefill_bs); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); -} - -void fast_topk_transform_ragged_interface( - const at::Tensor& score, - const at::Tensor& lengths, - at::Tensor& topk_indices_ragged, - const at::Tensor& topk_indices_offset, - std::optional row_starts_opt) { - CHECK_CUDA(score); - CHECK_CUDA(lengths); - CHECK_CUDA(topk_indices_ragged); - CHECK_CUDA(topk_indices_offset); - if (row_starts_opt.has_value()) { - CHECK_CUDA(row_starts_opt.value()); - } - - const auto params = get_params(score, lengths, row_starts_opt); - const auto B = score.size(0); - TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); - TORCH_CHECK(topk_indices_offset.dim() == 1); - - TORCH_CHECK(topk_indices_ragged.size(0) == B); - TORCH_CHECK(topk_indices_ragged.size(1) == TopK); - TORCH_CHECK(topk_indices_offset.size(0) == B); - - // launch kernel - const auto stream = at::cuda::getCurrentCUDAStream().stream(); - const auto grid = dim3{static_cast(B)}; - const auto block = dim3{kThreadsPerBlock}; - - setup_kernel_smem_once(); - topk_transform_prefill_ragged_kernel<<>>( - params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); -} diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 3022dda7..98557c77 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -48,7 +48,7 @@ TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) -REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index f0938fff..623bc824 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -36,7 +36,7 @@ MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) # REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" -REAL_HISTOGRAMS="${SCRIPT_DIR}/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in diff --git a/examples/verify_algo.py b/examples/verify_algo.py index f1c2a2d7..fb3e8437 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -254,15 +254,15 @@ def parse_args(): "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract (default: 0).', + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window (default: 0).', ) parser.add_argument( "--topk-mapping-power", type=float, default=0.5, - help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 7 asinh), alpha (mode 8 log1p). Default: 0.5.', + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6 asinh), alpha (mode 7 log1p), rho tail expansion (mode 12). Default: 0.5.', ) parser.add_argument( diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index f1c41fe5..5c5d6cf3 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -24,7 +24,7 @@ sparse_algos=( ) # Path to real-data histograms from calibration (for auto-tuning) -REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="" RESULTS_DIR="results" mkdir -p "${RESULTS_DIR}" @@ -212,6 +212,28 @@ for algo in "${sparse_algos[@]}"; do 2>&1 | tee "${OUTFILE}" done +# ============================================================ +# Step 8: Mode 12 (adaptive_tail_window), rho=4.0 +# ============================================================ +echo "" +echo "============================================================" +echo "Step 8: Mode 12 (adaptive_tail_window), rho=4.0" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_12_${TIMESTAMP}.log" + echo ">>> Mode 12 (adaptive_tail_window) algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val 30 \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 12 \ + --topk-mapping-power 4.0 \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + # ============================================================ # Summary # ============================================================ @@ -226,4 +248,5 @@ echo " Mode 8 (trunc8): (fixed)" echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" echo " Mode 11 (subtract): (fixed)" +echo " Mode 12 (tail_win): rho = 4.0" echo "============================================================" diff --git a/setup.py b/setup.py index 649f0a08..9c2186b9 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,9 @@ '-gencode=arch=compute_86,code=sm_86', '-gencode=arch=compute_89,code=sm_89', '-gencode=arch=compute_90,code=sm_90', + '-gencode=arch=compute_100a,code=sm_100a', + '-gencode=arch=compute_120,code=sm_120' + ], }, ), diff --git a/todo.txt b/todo.txt deleted file mode 100644 index 53950c3a..00000000 --- a/todo.txt +++ /dev/null @@ -1,308 +0,0 @@ -1. -prefill 8k/16k/32k block 16/32/64 block topk (8,16) -qwen series 0.6b, 1.7b, 4b, 8b, 16b, 32b -baselines: -flashinfer-fa2/fa3 flashattention v2/v3 -dense - -NSA: block sparse attention -benchmarking: -flash Sparse Attention -https://github.com/Relaxed-System-Lab/Flash-Sparse-Attention - -https://github.com/mit-han-lab/flash-moba - - -Video generation: VAE fp8 convolution -wan 2.1 vae -1.3B input 480P - -3.17: -(For SOSP 26) -warpper: prefill: -1. ragged paged, warpper -disable_radix_cache -new backend: goal sparse prefill with topk on a new warpper: abandon the previous paged warpper, -apply 1 to the whole prefill sequence - -2. topk -idea: we want to improve the current topk kernel /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu for our project, -we want to map the value in each layer for the topk selection to a new distribution that bucket sort is more efficient on. like to make the values to be more uniform -in each bucket. -the number of the heads has a certain distribution, try to adapt to it. -The mapping function should have a low overhead, or it would damage the end2end efficiency. -Key: first profile the whole process, record the distribution of the value in each layer. You need a profile script for this. save the results. -Then design a novel mapping function(can be easily customize by me), to map the value to a new distribution. Don't change the correctness of the sorting, but more efficient for the bucket sort in the /scr/dataset/yuke/xinrui/new/vortex_torch/csrc/topk_sglang.cu -here is some options: -Option A: Adaptive Bit Selection (2-Pass Min/Max on uint32 Key) - -Core idea: Instead of always extracting the top 8 bits of the fp16 key, extract the 8 most-significant varying bits of the full 32-bit key by finding the actual key range within each segment. - -Algorithm: -Pass 1: Parallel warp-reduction to find min_key and max_key over convert_to_uint32(x) for the segment (~N/1024 iterations per thread, same as current Stage 1) -Compute shift = max(0, 31 - clz(max_key - min_key) - 8) (find the bit position of the 8 most-significant differing bits) - -Pass 2: bin = ((convert_to_uint32(x) - min_key) >> shift) & 0xFF — this uses ALL 32 bits of float precision for binning -Overhead: ~2x data reads for Stage 1 (one extra scan for min/max). Min/max reduction can be done with __shfl_xor_sync followed by a single atomicMin/atomicMax in shared memory — very efficient. - -Pros: -Uses full 32-bit float precision instead of just 8 fp16 bits (up to 2^24 = 16M effective resolution levels instead of 256) -Perfectly adaptive to any data range — no calibration needed - -Pure integer arithmetic in the hot loop (shift + subtract + mask) -Guaranteed monotonic (linear mapping of uint32 keys) -Cons: -2x memory bandwidth for Stage 1 (the min/max pass re-reads all data) -Doesn't guarantee perfectly uniform bin counts (a skewed distribution within [min, max] still skews bins) -Expected quality: Excellent for the real distribution (where the range is narrow but contains many distinct float values that the 8-bit fp16 extraction cannot distinguish). -Implementation sketch: - -// Pass 1: find min/max key via parallel reduction -__shared__ uint32_t s_min_key, s_max_key; -if (tx == 0) { s_min_key = 0xFFFFFFFF; s_max_key = 0; } -__syncthreads(); -for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - uint32_t k = convert_to_uint32(input[idx + row_start]); - atomicMin(&s_min_key, k); - atomicMax(&s_max_key, k); -} -__syncthreads(); -uint32_t range = s_max_key - s_min_key; -int shift = max(0, 31 - __clz(range | 1) - 7); // 8 MSBs of range - -// Pass 2: histogram with adaptive bins -for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - uint32_t k = convert_to_uint32(input[idx + row_start]); - uint8_t bin = ((k - s_min_key) >> shift) & 0xFF; - atomicAdd(&s_histogram[bin], 1); -} - - - -Option B: Strided Sampling + Approximate CDF Equalization -Core idea: Read a sparse sample (every S-th element, e.g. S=32) to build an approximate histogram, compute a CDF-equalized LUT from it in shared memory, then apply the LUT during the full scan. -Algorithm: -Sampling pass: Read every 32nd element, build approximate 256-bin histogram in shared memory -In-block LUT construction: 256 threads compute prefix sum -> CDF -> equalized LUT (8 iterations of parallel prefix sum, same as existing run_cumsum) -Full scan: Apply s_lut[convert_to_uint8(x)] for each element -Overhead: ~(1/32 + 1)x = ~1.03x memory reads. LUT construction is ~8 syncthreads (trivial). The hot-loop cost is identical to existing Mode 1 (one shared memory lookup). -Pros: -Near-zero extra bandwidth (only 3% overhead from sampling) -Fully adaptive — no offline calibration needed -Self-tuning: the LUT is computed from the current segment's own data -Hot-loop cost identical to existing LUT mode (1 shared memory read) -Cons: -Approximate: sampling introduces noise in the estimated CDF (especially for small segments) -More complex control flow (3 phases in Stage 1) -For very short segments (<1024 elements), the sample may be too small -Expected quality: Very good. With 1/32 sampling on a segment of 4096+ elements, the CDF estimate has ~128+ samples — sufficient for a good 256-entry LUT. -Implementation sketch: - -// Phase 0: sampled histogram -__shared__ int s_sample_hist[256]; -if (tx < 256) s_sample_hist[tx] = 0; -__syncthreads(); -for (int idx = tx * 32; idx < length; idx += BLOCK_SIZE * 32) { - uint8_t bin = convert_to_uint8(input[idx + row_start]); - atomicAdd(&s_sample_hist[bin], 1); -} -__syncthreads(); - -// Phase 0.5: compute equalized LUT from sampled histogram -__shared__ uint8_t s_eq_lut[256]; -// ... prefix sum on s_sample_hist -> CDF -> s_eq_lut[i] = floor(CDF(i) * 255) - -// Phase 1: full scan with equalized LUT -for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - uint8_t raw_bin = convert_to_uint8(input[idx + row_start]); - uint8_t eq_bin = s_eq_lut[raw_bin]; - atomicAdd(&s_histogram[eq_bin], 1); -} - - -Option C: Online 2-Pass Histogram Equalization -Core idea: The "gold standard" — build the exact histogram first, compute CDF equalization in shared memory, then re-scan with the equalized mapping. -Algorithm: - -Pass 1: Build exact 256-bin histogram using convert_to_uint8(x) (same as current code) -In-block: 256 threads compute prefix sum -> CDF -> equalized LUT (lut[i] = floor(CDF(i) * 255)) - -Pass 2: Re-scan ALL elements, apply s_lut[convert_to_uint8(x)], build equalized histogram -> find threshold -Overhead: Exactly 2x data reads for Stage 1. LUT construction is negligible (~8 __syncthreads steps). -Pros: -Optimal histogram equalization — produces a provably uniform distribution (each equalized bin has almost exactly N/256 elements) -No calibration, no approximation — perfect adaptation to any input distribution -Monotonic (CDF is monotonically non-decreasing) - -Cons: -2x memory bandwidth for Stage 1 (dominant cost) -May not be worth it if Stage 2 is already fast (the overhead of the extra pass might exceed the savings from a smaller threshold bin) -Expected quality: Optimal. The threshold bin after equalization will have ~N/256 elements regardless of input distribution. -When this is worthwhile: When the original hot bin contains a very large fraction of elements (e.g., >20% of N), the savings from nearly eliminating Stage 2 easily outweigh the extra read pass. Given real data's Gini=0.809, this is likely the case. - -Option D: Temporal CDF Caching (Self-Calibrating LUT) -Core idea: The score distribution per head is relatively stable across adjacent decoding steps. After each kernel invocation, write the histogram to global memory. At the start of the NEXT invocation, load it and compute the CDF-equalized LUT. -Algorithm: -At kernel launch, load the previous iteration's histogram from a persistent global buffer prev_hist[head_id][256] -Compute equalized LUT in shared memory (prefix sum -> CDF -> LUT) -Stage 1 uses this LUT (same as Mode 1) -After Stage 1's histogram is built, write it to prev_hist[head_id][256] for the next iteration -Overhead: 1 shared memory lookup per element (identical to existing Mode 1). Plus ~256 int32 reads + ~256 int32 writes per kernel launch (negligible). -Pros: -Near-zero per-element overhead (shared memory LUT lookup) -Self-calibrating — no offline calibration step needed -Adapts to distribution changes over time (with 1-step lag) -Builds directly on existing Mode 1 infrastructure -Very low implementation complexity - -Cons: -1-step lag: the LUT is based on the previous step's data (cold start on first iteration) -Requires persistent global memory buffer (~256 * 4 bytes per head) -May produce suboptimal LUT if the distribution changes rapidly between steps -Expected quality: Very good after the first few iterations. Attention score distributions evolve slowly during generation, so the 1-step lag has minimal impact. -Implementation changes: -Add prev_histogram pointer to TopKMappingParams -Add MAPPING_TEMPORAL_CDF = 7 mode -In kernel: load prev histogram -> compute LUT -> use LUT -> write current histogram -Python side: allocate persistent [num_heads, 256] int32 buffer, pass to kernel - -Option E: Adaptive Exponent-Mantissa Bit Packing -Core idea: The current 8-bit extraction uses 5 exponent + 2 mantissa bits from fp16. When the actual exponent range is narrow (e.g., only 2-3 distinct exponents), most of those 5 exponent bits are wasted. Dynamically reallocate bits: use fewer for the exponent, more for the mantissa. -Algorithm: -Calibration or per-block scan: Determine exponent range [E_min, E_max] of the scores -Choose bit layout based on range width: -Range 1-2 exponents: 1 exp bit + 6 mantissa bits + 1 sign = 64 bins/exponent (vs 4 currently) → 16x improvement -Range 3-4: 2 exp + 5 mantissa + 1 sign -Range 5-8: 3 exp + 4 mantissa + 1 sign -Range 9-16: 4 exp + 3 mantissa + 1 sign -Wider: original 5 exp + 2 mantissa + 1 sign - -Apply: bin = ((exp - E_min) << mantissa_bits) | (mantissa >> (10 - mantissa_bits)) for positive values (with sign-magnitude ordering) -Overhead: Very low (~5-8 integer instructions per element: extract exponent, subtract base, shift, OR with mantissa). No extra memory reads. No LUT. -Pros - -Extremely low overhead — pure register-level bit manipulation -No extra memory reads or shared memory usage -Monotonic (order-preserving within each exponent, and across exponents) -Up to 16x better bin resolution for narrow distributions -Cons: -Requires knowing E_min/E_max (either calibrated offline, or from a quick per-block reduction) -Not as "perfect" as CDF equalization — distribution within each exponent may still be non-uniform -More complex bit manipulation logic -Expected quality: Very good for the observed real distribution (narrow exponent range → dramatic improvement in bin resolution). Not optimal for arbitrary distributions. -Implementation sketch: - -// Assuming E_min, E_max precomputed and passed in TopKMappingParams -__device__ __forceinline__ uint8_t map_adaptive_bits(float x, int e_min, int e_range) { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? ~bits : (bits | 0x8000); - int exp_val = (key >> 10) & 0x1F; // 5-bit exponent - int mantissa = key & 0x3FF; // 10-bit mantissa - - // Determine bit allocation based on e_range - int exp_bits, mant_bits; - if (e_range <= 2) { exp_bits = 1; mant_bits = 6; } - else if (e_range <= 4) { exp_bits = 2; mant_bits = 5; } - else if (e_range <= 8) { exp_bits = 3; mant_bits = 4; } - else { return (uint8_t)(key >> 8); } // fallback - - int sign_bit = (key >> 15) & 1; - int exp_part = min((exp_val - e_min), (1 << exp_bits) - 1); - int mant_part = mantissa >> (10 - mant_bits); - return (uint8_t)((sign_bit << 7) | (exp_part << mant_bits) | mant_part); -} -Use a new file to store these options, and make sure I can switch between these options. -3. Adapt Sparse attentions to the vortex this include: -(1) Naive Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/nsa_ref -(2) Flash Sparse Attention /scr/dataset/yuke/xinrui/Sparse-benchmark/Flash-Sparse-Attention/fsa -(3) FlashMoBA /scr/dataset/yuke/xinrui/Sparse-benchmark/flash-moba -Need to implement the whole sparse attention kernel, use their attention backend. Replace forward extend - - -# How to use the custom mapping function: ---- - Writing a Custom Mapping Function - - All mapping logic lives in csrc/topk_mapping.cuh. To add a new mode: - - Step 1: Add to the enum in topk_mapping.cuh: - - enum TopKMappingMode { - MAPPING_NONE = 0, - MAPPING_LUT_CDF = 1, - MAPPING_QUANTILE = 2, - MAPPING_POWER = 3, - MAPPING_LOG = 4, - MAPPING_CUSTOM = 5, // <-- your new mode - }; - - Step 2: Write your __device__ mapping function. It must take a float score and return a uint8_t bin index (0–255). The mapping must be monotonic (order-preserving) to ensure - correctness: - - __device__ __forceinline__ - uint8_t map_custom(float x) { - // Example: sqrt transform - float mapped = copysignf(sqrtf(fabsf(x)), x); - return convert_to_uint8(mapped); - } - - Step 3: Add a case to the dispatcher mapped_convert_to_uint8(): - - __device__ __forceinline__ - uint8_t mapped_convert_to_uint8(float x, const TopKMappingParams& params) { - switch (params.mode) { - // ... existing cases ... - case MAPPING_CUSTOM: - return map_custom(x); - default: - return convert_to_uint8(x); - } - } - - Step 4: Update verify_algo.py to accept the new mode value. In parse_args(), change the choices: - - parser.add_argument("--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5], # add 5 - ...) - - Step 5: Rebuild and test: - - pip install -e . - python examples/verify_algo.py --topk-type sglang --topk-mapping-mode 5 ... - - Key constraint: Your mapping function only affects Stage 1 (coarse 256-bin histogram). Stage 2 refinement always uses raw float bits via convert_to_uint32(), so the final - top-K selection is always correct regardless of your mapping. The goal is to make the Stage 1 histogram more uniform so fewer elements land in the threshold bin, reducing - Stage 2 work. - - If your custom mapping needs extra parameters (like a tensor or scalar), add them to the TopKMappingParams struct, pass them through topk_output_sglang() host function, update - register.h/register.cc bindings, and read them from ctx in output_func.py. - - 1. csrc/topk_sglang.cu — New CUDA kernel - - - TopKHitRate_Kernel: Stage-1-only kernel with mapping support. Builds 256-bin histogram, writes raw histogram to global memory, runs cumsum to find threshold bin, then - computes stage1_resolved = nblk - items_in_threshold_bin. No Stage 2 needed. - - topk_hit_rate(): Host entry point mirroring topk_output_sglang() for mapping param construction and dtype dispatch. - - 2. csrc/register.h + csrc/register.cc — PyBind11 bindings - - - Added topk_hit_rate declaration and m.def(...) binding with default args for mapping params. - - 3. benchmarks/bench_topk.py — Benchmark integration - - - Added compute_hit_rate_stats() helper for per-segment resolution rate + histogram stats. - - Added --hit-rate CLI flag. When enabled, iterates over available mapping modes (0, 3, 4 always; 1 if LUT provided; 2 if quantiles provided) and prints a comparison table. - - Results stored in config_results["hit_rate"] for JSON output. - - 4. benchmarks/analyze_topk_distribution.py — New visualization script - - - 5 plot functions: bin distribution, heatmap, before/after mapping, summary table, mode comparison. - - Loads from --profile-npz and/or --bench-json. - - 5. End-to-end integration (your request) - - - vortex_torch/indexer/context.py: Added topk_hit_rate_enabled slot, populated from sa.vortex_topk_hit_rate (default False). - - vortex_torch/indexer/output_func.py: After the main topk kernel call, if ctx.topk_hit_rate_enabled is True and topk_type is "sglang", it calls topk_hit_rate() and stores - results in self.last_hit_rate_stats / self.last_hit_rate_histograms. Zero overhead when disabled — just a getattr check that short-circuits. - - To enable during inference, set vortex_topk_hit_rate=True in your SGLang server args. \ No newline at end of file diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 53e97173..9c7e076c 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -309,6 +309,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_lut, mapping_quantiles, mapping_noscale, + ctx.topk_val, ) # Accumulate histograms for offline calibration _calibration_histograms.append(self.last_histograms.cpu().clone()) From 080c253430ec0da6ee88063b971ae68174f46c01 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 7 Apr 2026 06:30:19 +0000 Subject: [PATCH 16/24] Enhance TopK mapping with new modes and original sglang kernel support - Added ExpStretch and TopkWindow modes to analyze_topk_distribution.py and bench_topk.py. - Introduced topk_output_sglang_ori function for original sglang kernel in vortex_torch_C. - Updated autotune and benchmark scripts to include new modes and original kernel. - Modified example scripts to reflect changes in histogram calibration and TopK mapping parameters. --- benchmarks/analyze_topk_distribution.py | 4 + benchmarks/autotune_topk_mapping.py | 4 + benchmarks/bench_topk.py | 101 +++- csrc/register.cc | 6 + csrc/register.h | 13 + csrc/topk_mapping.cuh | 21 +- csrc/topk_sglang.cu | 363 +++++++++++- csrc/topk_slgang_ori.cu | 546 ++++++++++++++++++ examples/run_distribution_analysis.sh | 5 +- examples/run_distribution_analysis_new.sh | 9 +- examples/run_topk_benchmark.sh | 6 +- examples/verify_algo.py | 125 +++- examples/verify_algo.sh | 5 + examples/verify_algo_topk_mapping.sh | 44 +- .../verify_algo_topk_mapping_indexcache.sh | 45 -- examples/verify_algo_topk_mapping_new.sh | 166 +++++- third_party/sglang | 2 +- vortex_torch/indexer/output_func.py | 17 +- 18 files changed, 1342 insertions(+), 140 deletions(-) create mode 100644 csrc/topk_slgang_ori.cu delete mode 100644 examples/verify_algo_topk_mapping_indexcache.sh diff --git a/benchmarks/analyze_topk_distribution.py b/benchmarks/analyze_topk_distribution.py index 00cdf287..5531187e 100644 --- a/benchmarks/analyze_topk_distribution.py +++ b/benchmarks/analyze_topk_distribution.py @@ -43,6 +43,8 @@ 9: "Erf", 10: "Tanh", 11: "Subtract", + 13: "ExpStretch", + 14: "TopkWindow", } MAPPING_MODE_FORMULAS = { @@ -58,6 +60,8 @@ 9: "Erf: erf(alpha*x)", 10: "Tanh: tanh(alpha*x)", 11: "Subtract: x - pivot (RadiK-style)", + 13: "ExpStretch: exp(alpha*x)", + 14: "TopkWindow: k-aware linear windowing", } diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index d95c8399..f04418dc 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -39,6 +39,8 @@ 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), 10: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), + 13: ("alpha", [0.5, 1.0, 2.0, 4.0, 8.0]), + 14: ("rho", [2.0, 4.0, 8.0, 16.0]), } BASELINES = { 0: ("none", 0.5), @@ -64,6 +66,8 @@ 9: "erf", 10: "tanh", 11: "subtract", + 13: "exp_stretch", + 14: "topk_window", } diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 675092e5..2fd1e314 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -19,7 +19,7 @@ import torch from vortex_torch_C import ( - topk_output, topk_output_sglang, topk_profile_histogram, + topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram, topk_profile_stage1, topk_profile_counters, ) @@ -37,6 +37,8 @@ 9: "Erf", 10: "Tanh", 11: "Subtract", + 13: "ExpStretch", + 14: "TopkWindow", } MAPPING_MODE_FORMULAS = { @@ -52,6 +54,8 @@ 9: "Erf: erf(alpha*x)", 10: "Tanh: tanh(alpha*x)", 11: "Subtract: x - pivot (RadiK-style)", + 13: "ExpStretch: exp(alpha*x)", + 14: "TopkWindow: k-aware linear windowing", } @@ -209,7 +213,7 @@ def _load_autotune_powers(path: str) -> Dict[int, float]: best: Dict[int, dict] = {} for r in data: m = r.get("mode") - if m not in (3, 6, 7, 9, 10): + if m not in (3, 6, 7, 9, 10, 13, 14): continue if has_res_rate: score = r.get("res_rate_mean", 0.0) @@ -229,7 +233,8 @@ def _resolve_mode_power(args, mode: int) -> float: Priority: per-mode CLI flag > autotune JSON > global --mapping-power. """ per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7, - 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None)} + 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None), + 13: getattr(args, 'mapping_power_13', None), 14: getattr(args, 'mapping_power_14', None)} if mode in per_mode_flag and per_mode_flag[mode] is not None: return per_mode_flag[mode] if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: @@ -285,6 +290,7 @@ def run_benchmark(args) -> List[dict]: # Build kernel list all_kernels = { "naive": "naive", + "sglang_ori": "sglang_ori", "sglang_m0": "sglang_m0", "sglang_scale": "sglang_scale", # mode 3 with p=1.0 (identity + linear auto-range scaling) "sglang_m3": "sglang_m3", @@ -300,6 +306,9 @@ def run_benchmark(args) -> List[dict]: "sglang_m10": "sglang_m10", "sglang_m10_noscale": "sglang_m10_noscale", "sglang_m11": "sglang_m11", + "sglang_m13": "sglang_m13", + "sglang_m13_noscale": "sglang_m13_noscale", + "sglang_m14": "sglang_m14", } if mapping_lut is not None: all_kernels["sglang_m1"] = "sglang_m1" @@ -409,6 +418,20 @@ def run_benchmark(args) -> List[dict]: pages_per_seg, ) result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) + elif kernel_name == "sglang_ori": + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + result = bench_kernel(topk_output_sglang_ori, call_args, args.warmup, args.repeat) elif kernel_name == "sglang_scale": call_args = ( inputs["x"], @@ -437,7 +460,7 @@ def run_benchmark(args) -> List[dict]: elif mode == 2: extra_kwargs["mapping_quantiles"] = mapping_quantiles - if mode in (3, 6, 7, 9, 10): + if mode in (3, 6, 7, 9, 10, 13, 14): power = _resolve_mode_power(args, mode) else: power = 0.5 @@ -464,6 +487,8 @@ def run_benchmark(args) -> List[dict]: # Build label if kernel_name == "naive": label = "naive" + elif kernel_name == "sglang_ori": + label = "sglang Ori (no remap)" elif kernel_name == "sglang_scale": label = "sglang Scale Only (p=1.0)" else: @@ -471,14 +496,14 @@ def run_benchmark(args) -> List[dict]: m = int(m_str.split("_")[0]) noscale_suffix = " noscale" if kernel_name.endswith("_noscale") else "" mname = MAPPING_MODE_NAMES.get(m, f'm{m}') - if m in (3, 6, 7, 9, 10): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[m] + if m in (3, 6, 7, 9, 10, 13, 14): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[m] label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)}){noscale_suffix}" else: label = f"sglang {mname}{noscale_suffix}" - # Sub-phase profiling for sglang kernels - if kernel_name != "naive": + # Sub-phase profiling for sglang kernels (skip ori baseline) + if kernel_name not in ("naive", "sglang_ori"): if kernel_name == "sglang_scale": s1_mode, s1_power = 3, 1.0 s1_lut, s1_q = None, None @@ -487,7 +512,7 @@ def run_benchmark(args) -> List[dict]: s1_mode_str = kernel_name.split("_m")[1] s1_mode = int(s1_mode_str.split("_")[0]) s1_noscale = kernel_name.endswith("_noscale") - if s1_mode in (3, 6, 7, 9, 10): + if s1_mode in (3, 6, 7, 9, 10, 13, 14): s1_power = _resolve_mode_power(args, s1_mode) else: s1_power = 0.5 @@ -581,6 +606,46 @@ def run_benchmark(args) -> List[dict]: 'stage2_input_max': c[:, 5].max().item(), } + # Counter collection for kernels skipped by sub-phase profiling + if kernel_name in ("sglang_ori",) and args.counters: + inputs["sparse_kv_indices"].zero_() + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + counter_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + 0, # mode 0 (no mapping) — matches ori behavior + 0.5, + None, + None, + False, + ) + topk_profile_counters(*counter_args) + torch.cuda.synchronize() + c = counter_buf.float() + result['counters'] = { + 'threshold_bin_mean': c[:, 0].mean().item(), + 'num_above_mean': c[:, 1].mean().item(), + 'num_equal_mean': c[:, 2].mean().item(), + 'remaining_k_mean': c[:, 3].mean().item(), + 'refine_rounds_mean': c[:, 4].mean().item(), + 'stage2_input_mean': c[:, 5].mean().item(), + 'threshold_bin_max': c[:, 0].max().item(), + 'num_above_max': c[:, 1].max().item(), + 'num_equal_max': c[:, 2].max().item(), + 'remaining_k_max': c[:, 3].max().item(), + 'refine_rounds_max': c[:, 4].max().item(), + 'stage2_input_max': c[:, 5].max().item(), + } + kernel_entries.append((label, kernel_name, result)) config_results["kernels"][kernel_name] = result @@ -703,7 +768,7 @@ def run_benchmark(args) -> List[dict]: extra_lut = mapping_lut if mode == 1 else None extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10) else 0.5 + power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 topk_profile_histogram( hist_inputs["x"], @@ -716,6 +781,8 @@ def run_benchmark(args) -> List[dict]: power, extra_lut, extra_q, + False, # mapping_noscale + topk_val, # needed for mode 12/14 (tail/topk window) ) torch.cuda.synchronize() @@ -725,8 +792,8 @@ def run_benchmark(args) -> List[dict]: mformula = MAPPING_MODE_FORMULAS.get(mode, mname) mode_stats["name"] = mname mode_stats["formula"] = mformula - if mode in (3, 6, 7, 9, 10): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] + if mode in (3, 6, 7, 9, 10, 13, 14): + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[mode] mode_stats["param"] = f"{pname}={power}" display_name = f"{mname} ({pname}={power})" else: @@ -736,7 +803,7 @@ def run_benchmark(args) -> List[dict]: hist_entries.append((display_name, f"mode {mode:2d}", mode_stats)) # Noscale histogram analysis for parametric transform modes - noscale_modes = [m for m in (3, 6, 7, 9, 10) if m in modes_to_test] + noscale_modes = [m for m in (3, 6, 7, 9, 10, 13) if m in modes_to_test] for mode in noscale_modes: ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") power = _resolve_mode_power(args, mode) @@ -758,7 +825,7 @@ def run_benchmark(args) -> List[dict]: ns_stats["raw_counts"] = ns_hists.sum(dim=0).tolist() mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") mformula = MAPPING_MODE_FORMULAS.get(mode, mname) - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha"}[mode] + pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha"}[mode] ns_stats["name"] = f"{mname} noscale" ns_stats["formula"] = mformula ns_stats["param"] = f"{pname}={power}" @@ -831,9 +898,13 @@ def main(): help="Beta for mode 6 asinh (overrides --mapping-power)") parser.add_argument("--mapping-power-7", type=float, default=None, help="Alpha for mode 7 log1p (overrides --mapping-power)") + parser.add_argument("--mapping-power-13", type=float, default=None, + help="Alpha for mode 13 exp_stretch (overrides --mapping-power)") + parser.add_argument("--mapping-power-14", type=float, default=None, + help="Rho for mode 14 topk_window (overrides --mapping-power)") parser.add_argument("--autotune-json", type=str, default=None, help="Path to autotune_results.json — extracts best per-mode hyperparameters " - "(overrides --mapping-power for modes 3/6/7)") + "(overrides --mapping-power for modes 3/6/7/13/14)") parser.add_argument("--lut-path", type=str, default=None, help="Path to .npy uint8[256] LUT for mode=1") parser.add_argument("--quantiles-path", type=str, default=None, help="Path to .npy float32[256] for mode=2") parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file") diff --git a/csrc/register.cc b/csrc/register.cc index af49aecd..b968e9c5 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -19,6 +19,12 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), py::arg("mapping_noscale") = false); + m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), diff --git a/csrc/register.h b/csrc/register.h index dae5e825..d4a311ba 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -103,6 +103,19 @@ std::optional mapping_quantiles = std::nullopt, const bool mapping_noscale = false ); +void topk_output_sglang_ori( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + void topk_profile_histogram( const at::Tensor& x, const at::Tensor& dense_kv_indptr, diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 8dbb8084..08930083 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -38,6 +38,8 @@ enum TopKMappingMode { MAPPING_TANH = 10, // tanh(alpha * x) MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile + MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail + MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] }; struct TopKMappingParams { @@ -81,6 +83,12 @@ __device__ __forceinline__ float transform_tanh(float x, float alpha) { return tanhf(alpha * x); } +__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + // ---- Transform dispatcher (returns float, no bucketing) ---- __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { @@ -91,6 +99,7 @@ __device__ __forceinline__ float apply_transform(float x, const TopKMappingParam case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); case MAPPING_ERF: return transform_erf(x, params.power_exp); case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); default: return x; } } @@ -161,7 +170,8 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( case MAPPING_ASINH: case MAPPING_LOG1P: case MAPPING_ERF: - case MAPPING_TANH: { + case MAPPING_TANH: + case MAPPING_EXP_STRETCH: { float val = apply_transform(x, params); if (params.noscale) return convert_to_uint8(val); return linear_map_to_uint8(val, range_min, inv_range); @@ -171,6 +181,7 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( case MAPPING_SUBTRACT: return convert_to_uint8(x - range_min); // range_min repurposed as pivot case MAPPING_ADAPTIVE_TAIL_WINDOW: + case MAPPING_TOPK_WINDOW: return linear_map_to_uint8(x, range_min, inv_range); default: // MAPPING_NONE return convert_to_uint8(x); @@ -181,7 +192,8 @@ __device__ __forceinline__ uint8_t mapped_convert_to_uint8( __device__ __forceinline__ bool needs_auto_range(int mode) { return (mode == MAPPING_POWER || mode == MAPPING_LOG || mode == MAPPING_ASINH || mode == MAPPING_LOG1P || - mode == MAPPING_ERF || mode == MAPPING_TANH); + mode == MAPPING_ERF || mode == MAPPING_TANH || + mode == MAPPING_EXP_STRETCH); } // Helper: check if a mapping mode needs the pivot pre-pass @@ -193,3 +205,8 @@ __device__ __forceinline__ bool needs_pivot(int mode) { __device__ __forceinline__ bool needs_tail_window(int mode) { return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); } + +// Helper: check if mode is the lightweight topk-window pre-pass +__device__ __forceinline__ bool needs_topk_window(int mode) { + return (mode == MAPPING_TOPK_WINDOW); +} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 867efbed..1d12c309 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -450,7 +450,7 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) return __bfloat162float(x); } -constexpr int VORTEX_MAX_TOPK = 2048; +constexpr int VORTEX_MAX_TOPK = 4096; // Per-segment diagnostic counters written by WriteCounters mode constexpr int COUNTER_THRESHOLD_BIN = 0; // Stage 1 coarse threshold bin id @@ -461,7 +461,178 @@ constexpr int COUNTER_REFINE_ROUNDS = 4; // Stage 2 rounds used (0 = resolved constexpr int COUNTER_STAGE2_INPUT = 5; // candidates entering first Stage 2 refine round constexpr int NUM_TOPK_COUNTERS = 6; -// Templated version of fast_topk_cuda_tl: +// ====================================================================== +// Ori fast path: zero-overhead topk with no mapping infrastructure. +// Adapted from topk_slgang_ori.cu — uses direct convert_to_uint8() +// for Stage 1 binning with no pre-pass, no LUT, no bin cache. +// ====================================================================== +template +__device__ void fast_topk_ori( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: 8-bit coarse histogram (direct convert_to_uint8, no mapping) + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ====================================================================== +// Templated version of fast_topk_cuda_tl with mapping support: // - ScoreT: float or __nv_bfloat16 // - StopAfterStage1: return after Stage 1 route/filter (for profiling) // - WriteCounters: write diagnostic counters to global memory @@ -678,6 +849,47 @@ __device__ void fast_topk_vortex( s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; } __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Lightweight topk-window pre-pass: compute min/max of raw values, + // then focus all 256 bins on [tau_low, max] where + // tau_low = max - (max - min) * rho * k / length. + // Like mode 12 but uses a simple heuristic instead of quantile estimation. + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + __shared__ float s_warp_mins_tw2[32], s_warp_maxs_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins_tw2[warp_id] = local_min; s_warp_maxs_tw2[warp_id] = local_max; } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins_tw2[tx]; local_max = s_warp_maxs_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = rho * float(k) / float(length); + frac = fminf(frac, 1.0f); + float tau_low = local_max - (local_max - local_min) * frac; + if (tau_low >= local_max) tau_low = local_min; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); } else { if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } __syncthreads(); @@ -918,6 +1130,42 @@ void TopKOutput_Kernel( } } +// Ori fast-path wrapper: zero mapping overhead +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Ori_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + // ====================================================================== // Profiling Stage1 kernel: runs pre-pass + hist + cumsum + route/filter, // stops before Stage 2 refinement (for sub-phase timing) @@ -1382,9 +1630,100 @@ void topk_output_sglang( dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + // Fast path for mode 0 (MAPPING_NONE): use ori kernel with zero mapping overhead + if (mapping_mode == MAPPING_NONE) { + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + } else { + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Explicit ori baseline entry point — always uses the ori fast path +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_ori: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), @@ -1392,11 +1731,10 @@ void topk_output_sglang( sparse_kv_indices.data_ptr(), topk_val, reserved_bos, - reserved_eos, - mapping); + reserved_eos); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( + setup_kernel_smem_once, kSmem>(); + TopKOutput_Ori_Kernel<<>>( x.data_ptr(), dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), @@ -1404,17 +1742,14 @@ void topk_output_sglang( sparse_kv_indices.data_ptr(), topk_val, reserved_bos, - reserved_eos, - mapping); + reserved_eos); } else { - TORCH_CHECK(false, - "topk_output: unsupported dtype ", - x.scalar_type()); + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); } const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); + "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); } // ====================================================================== diff --git a/csrc/topk_slgang_ori.cu b/csrc/topk_slgang_ori.cu new file mode 100644 index 00000000..04a2b73b --- /dev/null +++ b/csrc/topk_slgang_ori.cu @@ -0,0 +1,546 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } +} + +__global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + static_assert(TopK % kThreadsPerBlock == 0); + static_assert(TopK / kThreadsPerBlock == 2); + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } +} + +auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} + +void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); +} diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 98557c77..6806eca5 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -48,8 +48,8 @@ TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" REAL_HISTOGRAMS="" - # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in @@ -162,9 +162,10 @@ PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ + --counters \ "${BENCH_EXTRA_ARGS[@]}" \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ + --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 623bc824..1f89c0b9 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -35,8 +35,8 @@ TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" # The path to the raw_histograms.npy file (set to skip calibration) -# REAL_HISTOGRAMS="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration/raw_histograms.npy" -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" +# REAL_HISTOGRAMS="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in @@ -55,7 +55,7 @@ export CUDA_VISIBLE_DEVICES="${GPU_ID}" RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" +RUN_DIR="${RESULTS_DIR}/dist_analysis_topk${TOPK_VAL}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" @@ -120,9 +120,10 @@ PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ --histogram \ + --counters \ --real-histograms "${REAL_HIST_PATH}" \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 \ + --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index 6ac2b9de..d57e2f1c 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -48,6 +48,7 @@ ALGO="block_sparse_attention" SKIP_CALIBRATE=false SKIP_KERNEL=false SKIP_E2E=true +BENCHMARKS="amc23" # space-separated list, e.g. "amc23 aime24" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do @@ -58,6 +59,7 @@ while [[ $# -gt 0 ]]; do --mem) MEM="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; --skip-calibrate) SKIP_CALIBRATE=true; shift ;; --skip-kernel) SKIP_KERNEL=true; shift ;; --skip-e2e) SKIP_E2E=true; shift ;; @@ -70,7 +72,8 @@ export CUDA_VISIBLE_DEVICES="${GPU_ID}" RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/topk_benchmark_${TIMESTAMP}" +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${BENCH_LABEL}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" @@ -194,6 +197,7 @@ else --trials "${TRIALS}" \ --topk-val "${TOPK_VAL}" \ --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ --mem "${MEM}" \ "$@" ; } \ 2>&1 | tee "${logfile}" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index fb3e8437..32ff5a39 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -51,6 +51,63 @@ def generate_requests(dataset: Dataset, field_name: str, data_format: str, trial return requests +BENCHMARK_REGISTRY = { + "amc23": { + "type": "jsonl", + "path": "amc23.jsonl", + "prompt_key": "prompt", + "answer_key": "answer", + "question_key": "question", + }, + "aime24": { + "type": "huggingface", + "path": "HuggingFaceH4/aime_2024", + "split": "train", + "field_name": "problem", + "answer_key": "answer", + }, +} + +def _load_benchmark(benchmark_name: str, trials: int, tokenizer=None): + """Load benchmark data and return (prompts, requests) tuple.""" + cfg = BENCHMARK_REGISTRY[benchmark_name] + + if cfg["type"] == "jsonl": + script_dir = os.path.dirname(os.path.abspath(__file__)) + jsonl_path = os.path.join(script_dir, cfg["path"]) + with open(jsonl_path, "r", encoding="utf-8") as f: + requests = [json.loads(line) for line in f] + requests = requests * trials + prompts = [req[cfg["prompt_key"]] for req in requests] + return prompts, requests + + elif cfg["type"] == "huggingface": + dataset = load_dataset(cfg["path"], split=cfg["split"]) + hf_requests = generate_requests(dataset, cfg["field_name"], MATH_QUERY_TEMPLATE) + # Normalize keys: ensure "question" and "answer" exist + for req in hf_requests: + if "question" not in req and cfg["field_name"] in req: + req["question"] = req[cfg["field_name"]] + # Build chat-template prompts if tokenizer is provided + if tokenizer is not None: + texts = [x["conversations"] for x in hf_requests] + prompts = [ + tokenizer.apply_chat_template( + text, tokenize=False, add_generation_prompt=True, enable_thinking=True + ) for text in texts + ] * trials + hf_requests = hf_requests * trials + else: + prompts = [ + MATH_QUERY_TEMPLATE.format(Question=x[cfg["field_name"]]) for x in hf_requests + ] * trials + hf_requests = hf_requests * trials + return prompts, hf_requests + + else: + raise ValueError(f"Unknown benchmark type: {cfg['type']}") + + def verify_algos( trials: int = 2, topk_val: int = 30, @@ -67,6 +124,7 @@ def verify_algos( topk_mapping_quantiles_path: str = None, index_cache_shared_layers: list = None, disable_cuda_graph: bool = False, +benchmark: str = "amc23", ): llm = sgl.Engine(model_path=model_name, @@ -90,11 +148,8 @@ def verify_algos( vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, vortex_index_cache_shared_layers=index_cache_shared_layers, ) - with open("amc23.jsonl", "r", encoding="utf-8") as f: - requests = [json.loads(line) for line in f] - - requests = requests * trials - prompts = [req["prompt"] for req in requests] + tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None + prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) sampling_params = {"temperature": 0.6, "top_p": 0.95, "top_k": 20, "max_new_tokens": 8192} @@ -247,15 +302,15 @@ def parse_args(): "--topk-type", type=str, default="naive", - choices=["naive", "sglang"], - help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang (default: "naive").', + choices=["naive", "sglang", "sglang_ori"], + help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang, "sglang_ori" for original sglang baseline (default: "naive").', ) parser.add_argument( "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window (default: 0).', + choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', ) parser.add_argument( @@ -287,6 +342,15 @@ def parse_args(): help="Layer IDs that reuse indices from the nearest preceding full layer (skip indexer).", ) + parser.add_argument( + "--benchmark", + type=str, + nargs="+", + default=["amc23"], + help="Benchmark(s) to run. Available: amc23, aime24. " + "Use multiple values to run several benchmarks sequentially (default: amc23).", + ) + return parser.parse_args() if __name__ == "__main__": @@ -298,22 +362,31 @@ def parse_args(): args.index_cache_shared_layers = list(range(2, 28, 2)) # [2,4,6,...,26] args.topk_mapping_mode = 0 - summary = verify_algos( - trials=args.trials, - topk_val=args.topk_val, - page_size=args.page_size, - vortex_module_name=args.vortex_module_name, - model_name=args.model_name, - sparse_attention=not(args.full_attention), - mem=args.mem, - kv_cache_dtype=args.kv_cache_dtype, - topk_type=args.topk_type, - topk_mapping_mode=args.topk_mapping_mode, - topk_mapping_power=args.topk_mapping_power, - topk_mapping_lut_path=args.topk_mapping_lut_path, - topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, - index_cache_shared_layers=args.index_cache_shared_layers, - ) - print(summary) + for bench_name in args.benchmark: + if bench_name not in BENCHMARK_REGISTRY: + print(f"WARNING: Unknown benchmark '{bench_name}', skipping. Available: {list(BENCHMARK_REGISTRY.keys())}") + continue + print(f"\n{'='*60}") + print(f"Benchmark: {bench_name}") + print(f"{'='*60}") + summary = verify_algos( + trials=args.trials, + topk_val=args.topk_val, + page_size=args.page_size, + vortex_module_name=args.vortex_module_name, + model_name=args.model_name, + sparse_attention=not(args.full_attention), + mem=args.mem, + kv_cache_dtype=args.kv_cache_dtype, + topk_type=args.topk_type, + topk_mapping_mode=args.topk_mapping_mode, + topk_mapping_power=args.topk_mapping_power, + topk_mapping_lut_path=args.topk_mapping_lut_path, + topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, + index_cache_shared_layers=args.index_cache_shared_layers, + benchmark=bench_name, + ) + summary["benchmark"] = bench_name + print(summary) exit(0) \ No newline at end of file diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index 3edf9b62..ddcd905e 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -24,3 +24,8 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done + + TORCH_CUDA_ARCH_LIST="12.0" \ + MAX_JOBS=64 \ + pip install -e . --no-build-isolation \ + -Ccmake.args="-DENABLE_BELOW_SM90=OFF" \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 2370ca19..c0a03c54 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -14,7 +14,18 @@ set -e # 9: Erf — y = erf(alpha * x) # 10: Tanh — y = tanh(alpha * x) # 11: Subtract — x - pivot (RadiK-style scatter) -export CUDA_VISIBLE_DEVICES=0 +GPU_ID=0 +BENCHMARKS="amc23" + +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" @@ -23,13 +34,13 @@ sparse_algos=( "block_sparse_attention" ) -RESULTS_DIR="results" +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +RESULTS_DIR="results/${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) # Set this to an existing calibration directory to skip re-running calibration. # It must contain lut.npy and quantiles.npy (output of calibrate_topk.py). -CALIBRATION_DIR="/scr/dataset/yuke/xinrui/new/vortex_torch/examples/calibration" - +CALIBRATION_DIR="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration" # ============================================================ # Baseline: naive topk (mode 0) # ============================================================ @@ -44,6 +55,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type naive \ --topk-mapping-mode 0 \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -167,6 +179,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode ${topk_mapping_mode} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -265,4 +278,25 @@ for algo in "${sparse_algos[@]}"; do --topk-mapping-power ${BEST_POWER_10} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" -done \ No newline at end of file +done + +# ============================================================ +# Counter profiling: collect COUNTER_NUM_EQUAL for all modes +# ============================================================ +echo "" +echo "============================================================" +echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=30)" +echo "============================================================" +COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals 30 \ + --num-kv-heads 2 \ + --distributions normal \ + --counters \ + --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 \ + --repeat 5 \ + --output-json "${COUNTER_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" +echo ">>> Counters saved to ${COUNTER_JSON}" \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_indexcache.sh b/examples/verify_algo_topk_mapping_indexcache.sh deleted file mode 100644 index 9002084c..00000000 --- a/examples/verify_algo_topk_mapping_indexcache.sh +++ /dev/null @@ -1,45 +0,0 @@ -#!/usr/bin/env bash -set -e -# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use -export CUDA_VISIBLE_DEVICES=5 - -RESULTS_DIR="results" -mkdir -p "${RESULTS_DIR}" -TIMESTAMP=$(date +%Y%m%d_%H%M%S) - -sparse_algos=( - "block_sparse_attention" -) - -# --- Mode 5: Index Cache (default even-layer pattern) --- -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode5_index_cache_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 5 (index cache)" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 5 \ - --index-cache-shared-layers 2 4 6 8 10 12 14 16 18 20 22 24 26 \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# --- Mode 6: Greedy layer selection --- -# for algo in "${sparse_algos[@]}"; do -# OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_mode6_greedy_${TIMESTAMP}.log" -# echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode 6 (greedy)" -# echo ">>> Saving results to ${OUTFILE}" -# { time python verify_algo.py \ -# --trials 8 \ -# --topk-val 30 \ -# --vortex-module-name "${algo}" \ -# --model-name Qwen/Qwen3-1.7B \ -# --topk-type sglang \ -# --topk-mapping-mode 6 \ -# --mem 0.7 ; } \ -# 2>&1 | tee "${OUTFILE}" -#done diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 5c5d6cf3..4c96a15b 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -14,19 +14,35 @@ set -e # 9: Erf — y = erf(alpha * x) # 10: Tanh — y = tanh(alpha * x) # 11: Subtract — x - pivot (RadiK-style scatter) -export CUDA_VISIBLE_DEVICES=5 - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +TOPK_VAL=30 +BENCHMARKS="amc23" # space-separated list, e.g. "amc23 aime24" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + sparse_algos=( "block_sparse_attention" ) # Path to real-data histograms from calibration (for auto-tuning) -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" -RESULTS_DIR="results" +BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') +RESULTS_DIR="results/topk${TOPK_VAL}_${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -39,7 +55,7 @@ echo "Step 0: Auto-tuning hyperparameters (synthetic data)" echo "============================================================" AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --batch-size 4 \ --seq-len 32768 \ --num-kv-heads 2 \ @@ -58,15 +74,35 @@ data = json.load(open(sys.argv[1])) best = {} for r in data: m = r.get('mode') - if m in (3, 6, 7, 9, 10): + if m in (3, 6, 7, 9, 10, 13, 14): if m not in best or r['gini'] < best[m]['gini']: best[m] = r -for m in (3, 6, 7, 9, 10): +for m in (3, 6, 7, 9, 10, 13, 14): print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') " "${AUTOTUNE_JSON}")" -echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" +echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10} mode13=${BEST_POWER_13} mode14=${BEST_POWER_14}" echo "" +# ============================================================ +# Baseline: Original sglang kernel (no remap) +# ============================================================ +echo "============================================================" +echo "Baseline: sglang_ori (no remap)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_ori_${TIMESTAMP}.log" + echo ">>> sglang_ori algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val ${TOPK_VAL} \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang_ori \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + # ============================================================ # Step 1: Mode 3 (power) — autotuned best p # ============================================================ @@ -78,12 +114,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 3 (power) p=${BEST_POWER_3} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 3 \ --topk-mapping-power ${BEST_POWER_3} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -99,12 +136,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 6 (asinh) beta=${BEST_POWER_6} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 6 \ --topk-mapping-power ${BEST_POWER_6} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -120,12 +158,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 7 (log1p) alpha=${BEST_POWER_7} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 7 \ --topk-mapping-power ${BEST_POWER_7} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -141,11 +180,12 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 8 (trunc8) algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 8 \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -161,12 +201,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 9 (erf) alpha=${BEST_POWER_9} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 9 \ --topk-mapping-power ${BEST_POWER_9} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -182,12 +223,13 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 10 (tanh) alpha=${BEST_POWER_10} algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 10 \ --topk-mapping-power ${BEST_POWER_10} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -203,11 +245,12 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 11 (subtract) algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 11 \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -224,16 +267,88 @@ for algo in "${sparse_algos[@]}"; do echo ">>> Mode 12 (adaptive_tail_window) algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val ${TOPK_VAL} \ --vortex-module-name "${algo}" \ --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 12 \ --topk-mapping-power 4.0 \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 9: Mode 13 (exp_stretch) — autotuned best alpha +# ============================================================ +echo "" +echo "============================================================" +echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_POWER_13} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_POWER_13}_${TIMESTAMP}.log" + echo ">>> Mode 13 (exp_stretch) alpha=${BEST_POWER_13} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val ${TOPK_VAL} \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 13 \ + --topk-mapping-power ${BEST_POWER_13} \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 ; } \ + 2>&1 | tee "${OUTFILE}" +done + +# ============================================================ +# Step 10: Mode 14 (topk_window) — autotuned best rho +# ============================================================ +echo "" +echo "============================================================" +echo "Step 10: Mode 14 (topk_window) — rho=${BEST_POWER_14} (autotuned)" +echo "============================================================" +for algo in "${sparse_algos[@]}"; do + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_POWER_14}_${TIMESTAMP}.log" + echo ">>> Mode 14 (topk_window) rho=${BEST_POWER_14} algo=${algo}" + { time python verify_algo.py \ + --trials 8 \ + --topk-val ${TOPK_VAL} \ + --vortex-module-name "${algo}" \ + --model-name Qwen/Qwen3-1.7B \ + --topk-type sglang \ + --topk-mapping-mode 14 \ + --topk-mapping-power ${BEST_POWER_14} \ + --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done +# ============================================================ +# Counter profiling: collect COUNTER_NUM_EQUAL for all modes +# (single extra kernel call per mode, no overhead on accuracy runs) +# ============================================================ +echo "" +echo "============================================================" +echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=${TOPK_VAL})" +echo "============================================================" +COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --batch-sizes 4 \ + --seq-lens 4096 \ + --topk-vals ${TOPK_VAL} \ + --num-kv-heads 2 \ + --distributions normal \ + --counters \ + --real-histograms "${REAL_HISTOGRAMS}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 sglang_m13 sglang_m14 \ + --mapping-power-13 ${BEST_POWER_13} --mapping-power-14 ${BEST_POWER_14} \ + --repeat 5 \ + --output-json "${COUNTER_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" +echo ">>> Counters saved to ${COUNTER_JSON}" + # ============================================================ # Summary # ============================================================ @@ -241,12 +356,15 @@ echo "" echo "============================================================" echo "All runs complete. Results in ${RESULTS_DIR}/" echo " Auto-tune: ${AUTOTUNE_JSON}" -echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" -echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" -echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" -echo " Mode 8 (trunc8): (fixed)" -echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" -echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" -echo " Mode 11 (subtract): (fixed)" -echo " Mode 12 (tail_win): rho = 4.0" +echo " Counters: ${COUNTER_JSON}" +echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" +echo " Mode 8 (trunc8): (fixed)" +echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" +echo " Mode 11 (subtract): (fixed)" +echo " Mode 12 (tail_win): rho = 4.0" +echo " Mode 13 (exp_stretch):alpha = ${BEST_POWER_13} (autotuned)" +echo " Mode 14 (topk_window):rho = ${BEST_POWER_14} (autotuned)" echo "============================================================" diff --git a/third_party/sglang b/third_party/sglang index 0ec12893..47faead5 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 0ec12893c4fc0d6ae1d36d4e0512dc21749c4b4b +Subproject commit 47faead5448b14681ac57fc9a3c6311654fc2b17 diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index 9c7e076c..b50ca74f 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,7 +1,7 @@ import torch from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output, topk_output_sglang, topk_profile_histogram +from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT from ..utils import UNSET @@ -91,6 +91,7 @@ class topK(vOp): FORMAT.RAGGED: { "naive": topk_output, "sglang": topk_output_sglang, + "sglang_ori": topk_output_sglang_ori, }, } @@ -272,6 +273,20 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso mapping_quantiles, mapping_noscale, ) + elif self.topk_type == "sglang_ori": + # topk_output_sglang_ori: same CSR interface, no mapping params + self.impl( + x, + ctx.dense_kv_indptr, + ctx.sparse_kv_indptr, + ctx.dense_kv_indices, + o, + ctx.batch_size * ctx.num_kv_heads, + ctx.topk_val, + ctx.page_reserved_bos, + ctx.page_reserved_eos, + ctx.max_num_pages_per_request, + ) else: # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) self.impl( From e6b73e45490752d8ab9104f9f89feeae0a40a3d8 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 7 Apr 2026 21:20:34 +0000 Subject: [PATCH 17/24] Update TopK mapping and profiling functionalities - Added new file for enhanced profiling capabilities. - Updated to include the new profiling source file. - Modified to expand the sweep grid for parameter tuning. - Refactored to improve handling of hyperparameters and added subprocess profiling for large TopK values. - Enhanced and to support new parameters for profiling. - Updated example scripts to reflect changes in TopK parameters and profiling options. --- benchmarks/autotune_topk_mapping.py | 3 +- benchmarks/bench_topk.py | 351 +++--- csrc/register.cc | 10 +- csrc/register.h | 10 +- csrc/topk_mapping.cuh | 2 +- csrc/topk_sglang.cu | 807 +++----------- csrc/topk_sglang_profile.cu | 1203 +++++++++++++++++++++ examples/run_distribution_analysis.sh | 22 +- examples/run_distribution_analysis_new.sh | 25 +- examples/verify_algo.py | 30 +- examples/verify_algo_topk_mapping.sh | 44 +- examples/verify_algo_topk_mapping_new.sh | 78 +- setup.py | 1 + vortex_torch/indexer/output_func.py | 6 +- 14 files changed, 1663 insertions(+), 929 deletions(-) create mode 100644 csrc/topk_sglang_profile.cu diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index f04418dc..8051c14e 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -34,7 +34,7 @@ SWEEP_GRID = { # (mode, param_name, param_values) - 3: ("power_exp", [0.1, 0.25, 0.75, 0.9]), + 3: ("power_exp", [0.1, 0.25, 0.5, 0.75, 0.9, 2.0, 4.0]), 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), @@ -55,6 +55,7 @@ 7: ("log1p_noscale", [1.0]), 9: ("erf_noscale", [1.0]), 10: ("tanh_noscale", [1.0]), + 13: ("exp_stretch_noscale", [1.0, 4.0]), } MODE_NAMES = { 0: "none", diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 2fd1e314..a913bde5 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -227,19 +227,118 @@ def _load_autotune_powers(path: str) -> Dict[int, float]: return {m: v["param"] for m, v in best.items()} -def _resolve_mode_power(args, mode: int) -> float: +def _resolve_mode_hparam(args, mode: int) -> float: """Return the power/beta/alpha for a parametric mapping mode. Priority: per-mode CLI flag > autotune JSON > global --mapping-power. """ - per_mode_flag = {3: args.mapping_power_3, 6: args.mapping_power_6, 7: args.mapping_power_7, - 9: getattr(args, 'mapping_power_9', None), 10: getattr(args, 'mapping_power_10', None), - 13: getattr(args, 'mapping_power_13', None), 14: getattr(args, 'mapping_power_14', None)} + per_mode_flag = {3: args.mapping_hparam_3, 6: args.mapping_hparam_6, 7: args.mapping_hparam_7, + 9: getattr(args, 'mapping_hparam_9', None), 10: getattr(args, 'mapping_hparam_10', None), + 13: getattr(args, 'mapping_hparam_13', None), 14: getattr(args, 'mapping_hparam_14', None)} if mode in per_mode_flag and per_mode_flag[mode] is not None: return per_mode_flag[mode] if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: return args._autotune_powers[mode] - return args.mapping_power + return args.mapping_hparam + + +def _run_subphase_profiling(subphase_modes, inputs, eff_bs, topk_val, + pages_per_seg, args, mapping_lut, mapping_quantiles): + """Run sub-phase profiling (histogram_only + stage1_full) for each mode. + + For topk <= 512, runs inline. For topk > 512, runs each mode in a + separate subprocess to avoid CUDA shared memory exhaustion from + accumulated kernel template registrations. + """ + import subprocess, sys, tempfile, os + + for kernel_name, s1_mode, s1_power, s1_noscale, result in subphase_modes: + s1_lut = mapping_lut if s1_mode == 1 else None + s1_q = mapping_quantiles if s1_mode == 2 else None + + if topk_val <= 512: + # Inline: run directly in this process + hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + hist_args = ( + inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, + args.reserved_bos, args.reserved_eos, + s1_mode, s1_power, s1_lut, s1_q, s1_noscale, + ) + hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) + + inputs["sparse_kv_indices"].zero_() + stage1_args = ( + inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + s1_mode, s1_power, s1_lut, s1_q, s1_noscale, + ) + stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) + else: + # Subprocess: fresh CUDA context per mode to avoid shared memory exhaustion + script = f""" +import torch, json, sys +sys.path.insert(0, '{os.path.dirname(os.path.abspath(__file__))}') +from vortex_torch_C import topk_profile_histogram, topk_profile_stage1 +from bench_topk import make_topk_inputs, bench_kernel + +inputs = make_topk_inputs( + batch_size={inputs['x'].shape[0] // (eff_bs // (inputs['x'].shape[0] if eff_bs == inputs['x'].shape[0] else 1)) if False else 1}, + num_kv_heads=1, seq_len={pages_per_seg * 16}, + page_size=16, topk_val={topk_val}, + reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, + score_dtype=torch.bfloat16, distribution="normal", +) +eff_bs = {eff_bs} +# Recreate inputs with correct eff_bs +inputs = make_topk_inputs( + batch_size={eff_bs // max(1, eff_bs // pages_per_seg) if False else eff_bs}, + num_kv_heads=1, seq_len={pages_per_seg * 16}, + page_size=16, topk_val={topk_val}, + reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, + score_dtype=torch.bfloat16, distribution="normal", +) +eff_bs = inputs["eff_batch_size"] + +hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") +hist_result = bench_kernel(topk_profile_histogram, + (inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, + {args.reserved_bos}, {args.reserved_eos}, {s1_mode}, {s1_power}, + None, None, {s1_noscale}), 5, {args.repeat}) + +inputs["sparse_kv_indices"].zero_() +stage1_result = bench_kernel(topk_profile_stage1, + (inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + eff_bs, {topk_val}, {args.reserved_bos}, {args.reserved_eos}, + inputs["num_pages_per_seg"], {s1_mode}, {s1_power}, + None, None, {s1_noscale}), 5, {args.repeat}) + +print(json.dumps({{"hist": hist_result, "stage1": stage1_result}})) +""" + try: + proc = subprocess.run( + [sys.executable, "-c", script], + capture_output=True, text=True, timeout=60, + env={**os.environ, "PYTHONPATH": os.path.dirname(os.path.abspath(__file__)) + "/.."}) + if proc.returncode == 0: + data = json.loads(proc.stdout.strip().split("\n")[-1]) + hist_result = data["hist"] + stage1_result = data["stage1"] + else: + # Subprocess failed — skip sub-phase for this mode + continue + except Exception: + continue + + result['histogram_only_mean_ms'] = hist_result['mean_ms'] + result['histogram_only_median_ms'] = hist_result['median_ms'] + result['stage1_full_mean_ms'] = stage1_result['mean_ms'] + result['stage1_full_median_ms'] = stage1_result['median_ms'] + result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] + result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] + result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] + result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] def run_benchmark(args) -> List[dict]: @@ -275,6 +374,7 @@ def run_benchmark(args) -> List[dict]: print(f"TopK Kernel Benchmark Results") print(f"GPU: {gpu_name} | SM count: {gpu_props.multi_processor_count}") print(f"Score dtype: {args.score_dtype} | Warmup: {args.warmup} | Repeat: {args.repeat}") + print(f"Radix bits: {args.radix_bits} ({1 << args.radix_bits} bins) | Sample stride: {args.sample_stride}") print("=" * 90) # Load optional LUT / quantiles @@ -430,6 +530,7 @@ def run_benchmark(args) -> List[dict]: args.reserved_bos, args.reserved_eos, pages_per_seg, + args.radix_bits, ) result = bench_kernel(topk_output_sglang_ori, call_args, args.warmup, args.repeat) elif kernel_name == "sglang_scale": @@ -448,6 +549,9 @@ def run_benchmark(args) -> List[dict]: 1.0, # p=1.0 → identity None, None, + False, + args.sample_stride, + args.radix_bits, ) result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) else: @@ -461,7 +565,7 @@ def run_benchmark(args) -> List[dict]: extra_kwargs["mapping_quantiles"] = mapping_quantiles if mode in (3, 6, 7, 9, 10, 13, 14): - power = _resolve_mode_power(args, mode) + power = _resolve_mode_hparam(args, mode) else: power = 0.5 @@ -481,6 +585,8 @@ def run_benchmark(args) -> List[dict]: extra_kwargs.get("mapping_lut", None), extra_kwargs.get("mapping_quantiles", None), is_noscale, + args.sample_stride, + args.radix_bits, ) result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) @@ -498,116 +604,23 @@ def run_benchmark(args) -> List[dict]: mname = MAPPING_MODE_NAMES.get(m, f'm{m}') if m in (3, 6, 7, 9, 10, 13, 14): pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[m] - label = f"sglang {mname} ({pname}={_resolve_mode_power(args, m)}){noscale_suffix}" + label = f"sglang {mname} ({pname}={_resolve_mode_hparam(args, m)}){noscale_suffix}" else: label = f"sglang {mname}{noscale_suffix}" - # Sub-phase profiling for sglang kernels (skip ori baseline) - if kernel_name not in ("naive", "sglang_ori"): - if kernel_name == "sglang_scale": - s1_mode, s1_power = 3, 1.0 - s1_lut, s1_q = None, None - s1_noscale = False + # Counter collection (runs separately from sub-phase profiling) + if kernel_name not in ("naive",) and args.counters: + if kernel_name in ("sglang_ori",): + c_mode, c_power, c_lut, c_q, c_noscale = 0, 0.5, None, None, False + elif kernel_name == "sglang_scale": + c_mode, c_power, c_lut, c_q, c_noscale = 3, 1.0, None, None, False else: - s1_mode_str = kernel_name.split("_m")[1] - s1_mode = int(s1_mode_str.split("_")[0]) - s1_noscale = kernel_name.endswith("_noscale") - if s1_mode in (3, 6, 7, 9, 10, 13, 14): - s1_power = _resolve_mode_power(args, s1_mode) - else: - s1_power = 0.5 - s1_lut = mapping_lut if s1_mode == 1 else None - s1_q = mapping_quantiles if s1_mode == 2 else None - - # Histogram only: pre-pass + histogram build - hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") - hist_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - hist_buf, - eff_bs, - args.reserved_bos, - args.reserved_eos, - s1_mode, - s1_power, - s1_lut, - s1_q, - s1_noscale, - ) - hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) - - # Stage1 full: pre-pass + hist + cumsum + route/filter - inputs["sparse_kv_indices"].zero_() - stage1_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - s1_mode, - s1_power, - s1_lut, - s1_q, - s1_noscale, - ) - stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) - - result['histogram_only_mean_ms'] = hist_result['mean_ms'] - result['histogram_only_median_ms'] = hist_result['median_ms'] - result['stage1_full_mean_ms'] = stage1_result['mean_ms'] - result['stage1_full_median_ms'] = stage1_result['median_ms'] - result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] - result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] - result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] - result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] - - # Optional counter collection - if args.counters: - inputs["sparse_kv_indices"].zero_() - counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") - counter_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - counter_buf, - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - s1_mode, - s1_power, - s1_lut, - s1_q, - s1_noscale, - ) - topk_profile_counters(*counter_args) - torch.cuda.synchronize() - c = counter_buf.float() - result['counters'] = { - 'threshold_bin_mean': c[:, 0].mean().item(), - 'num_above_mean': c[:, 1].mean().item(), - 'num_equal_mean': c[:, 2].mean().item(), - 'remaining_k_mean': c[:, 3].mean().item(), - 'refine_rounds_mean': c[:, 4].mean().item(), - 'stage2_input_mean': c[:, 5].mean().item(), - 'threshold_bin_max': c[:, 0].max().item(), - 'num_above_max': c[:, 1].max().item(), - 'num_equal_max': c[:, 2].max().item(), - 'remaining_k_max': c[:, 3].max().item(), - 'refine_rounds_max': c[:, 4].max().item(), - 'stage2_input_max': c[:, 5].max().item(), - } - - # Counter collection for kernels skipped by sub-phase profiling - if kernel_name in ("sglang_ori",) and args.counters: + c_mode_str = kernel_name.split("_m")[1] + c_mode = int(c_mode_str.split("_")[0]) + c_noscale = kernel_name.endswith("_noscale") + c_power = _resolve_mode_hparam(args, c_mode) if c_mode in (3,6,7,9,10,13,14) else 0.5 + c_lut = mapping_lut if c_mode == 1 else None + c_q = mapping_quantiles if c_mode == 2 else None inputs["sparse_kv_indices"].zero_() counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") counter_args = ( @@ -622,11 +635,11 @@ def run_benchmark(args) -> List[dict]: args.reserved_bos, args.reserved_eos, pages_per_seg, - 0, # mode 0 (no mapping) — matches ori behavior - 0.5, - None, - None, - False, + c_mode, + c_power, + c_lut, + c_q, + c_noscale, ) topk_profile_counters(*counter_args) torch.cuda.synchronize() @@ -649,6 +662,27 @@ def run_benchmark(args) -> List[dict]: kernel_entries.append((label, kernel_name, result)) config_results["kernels"][kernel_name] = result + # Second pass: sub-phase profiling (histogram_only + stage1_full) + # Run in a subprocess to get a fresh CUDA context, avoiding + # shared memory exhaustion from accumulated kernel registrations. + subphase_modes = [] + for label, kernel_name, result in kernel_entries: + if kernel_name in ("naive", "sglang_ori"): + continue + if kernel_name == "sglang_scale": + s1_mode, s1_power, s1_noscale = 3, 1.0, False + else: + s1_mode_str = kernel_name.split("_m")[1] + s1_mode = int(s1_mode_str.split("_")[0]) + s1_noscale = kernel_name.endswith("_noscale") + s1_power = _resolve_mode_hparam(args, s1_mode) if s1_mode in (3,6,7,9,10,13,14) else 0.5 + subphase_modes.append((kernel_name, s1_mode, s1_power, s1_noscale, result)) + + if subphase_modes: + _run_subphase_profiling( + subphase_modes, inputs, eff_bs, topk_val, + pages_per_seg, args, mapping_lut, mapping_quantiles) + # Print kernel results sorted by mean latency (ascending) kernel_entries.sort(key=lambda e: e[2]['mean_ms']) print(f" --- kernel latency (sorted by mean, ascending) ---") @@ -692,44 +726,13 @@ def run_benchmark(args) -> List[dict]: f"stage2_input={c['stage2_input_mean']:.0f}" ) - # Histogram analysis + # Histogram analysis — uses the SAME inputs as the main benchmark + # so histogram CSV and counters reflect the same data. if args.histogram: - # Build a separate (potentially larger) dataset for histogram profiling - target_pages = (args.histogram_pages - if args.histogram_pages is not None - else _histogram_target_pages(pages_per_seg)) + hist_inputs = inputs + hist_eff_bs = eff_bs current_pages = eff_bs * pages_per_seg - if target_pages > current_pages: - hist_bs = math.ceil(target_pages / (num_kv_heads * pages_per_seg)) - if dist == "real" and real_histogram is not None: - hist_inputs = make_topk_inputs( - batch_size=hist_bs, num_kv_heads=num_kv_heads, - seq_len=seq_len, page_size=args.page_size, - topk_val=topk_val, reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, score_dtype=score_dtype, - distribution="normal", - ) - total_hist_dense = hist_inputs["eff_batch_size"] * hist_inputs["num_pages_per_seg"] - hist_inputs["x"] = _scores_from_histogram(real_histogram, total_hist_dense, device="cuda") - else: - hist_inputs = make_topk_inputs( - batch_size=hist_bs, num_kv_heads=num_kv_heads, - seq_len=seq_len, page_size=args.page_size, - topk_val=topk_val, reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, score_dtype=score_dtype, - distribution=dist, - ) - hist_eff_bs = hist_inputs["eff_batch_size"] - actual_pages = hist_eff_bs * pages_per_seg - print( - f" histogram dataset : {actual_pages} pages " - f"(upscaled from {current_pages} for statistical reliability)" - ) - else: - hist_inputs = inputs - hist_eff_bs = eff_bs - actual_pages = current_pages - print(f" histogram dataset : {actual_pages} pages") + print(f" histogram dataset : {current_pages} pages (same as benchmark)") # Raw unmapped histogram histograms = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") @@ -756,7 +759,7 @@ def run_benchmark(args) -> List[dict]: histograms_results = {} # Per-mode histogram analysis (scaled) - modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11] + modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14] if mapping_lut is not None: modes_to_test.append(1) if mapping_quantiles is not None: @@ -768,7 +771,7 @@ def run_benchmark(args) -> List[dict]: extra_lut = mapping_lut if mode == 1 else None extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_power(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 + power = _resolve_mode_hparam(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 topk_profile_histogram( hist_inputs["x"], @@ -806,7 +809,7 @@ def run_benchmark(args) -> List[dict]: noscale_modes = [m for m in (3, 6, 7, 9, 10, 13) if m in modes_to_test] for mode in noscale_modes: ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - power = _resolve_mode_power(args, mode) + power = _resolve_mode_hparam(args, mode) topk_profile_histogram( hist_inputs["x"], hist_inputs["dense_kv_indptr"], @@ -890,18 +893,24 @@ def main(): parser.add_argument("--distributions", nargs="+", default=["normal", "lognormal", "uniform"]) parser.add_argument("--warmup", type=int, default=10) parser.add_argument("--repeat", type=int, default=100) - parser.add_argument("--mapping-power", type=float, default=0.5, - help="Global fallback power parameter for parametric modes (default: 0.5)") - parser.add_argument("--mapping-power-3", type=float, default=None, - help="Power exponent p for mode 3 (overrides --mapping-power)") - parser.add_argument("--mapping-power-6", type=float, default=None, - help="Beta for mode 6 asinh (overrides --mapping-power)") - parser.add_argument("--mapping-power-7", type=float, default=None, - help="Alpha for mode 7 log1p (overrides --mapping-power)") - parser.add_argument("--mapping-power-13", type=float, default=None, - help="Alpha for mode 13 exp_stretch (overrides --mapping-power)") - parser.add_argument("--mapping-power-14", type=float, default=None, - help="Rho for mode 14 topk_window (overrides --mapping-power)") + parser.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, + dest="mapping_hparam", + help="Global fallback hyperparameter for parametric modes (default: 0.5)") + parser.add_argument("--mapping-hparam-3", "--mapping-power-3", type=float, default=None, + dest="mapping_hparam_3", + help="Power exponent p for mode 3 (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-6", "--mapping-power-6", type=float, default=None, + dest="mapping_hparam_6", + help="Beta for mode 6 asinh (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-7", "--mapping-power-7", type=float, default=None, + dest="mapping_hparam_7", + help="Alpha for mode 7 log1p (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-13", "--mapping-power-13", type=float, default=None, + dest="mapping_hparam_13", + help="Alpha for mode 13 exp_stretch (overrides --mapping-hparam)") + parser.add_argument("--mapping-hparam-14", "--mapping-power-14", type=float, default=None, + dest="mapping_hparam_14", + help="Rho for mode 14 topk_window (overrides --mapping-hparam)") parser.add_argument("--autotune-json", type=str, default=None, help="Path to autotune_results.json — extracts best per-mode hyperparameters " "(overrides --mapping-power for modes 3/6/7/13/14)") @@ -920,6 +929,12 @@ def main(): parser.add_argument("--counters", action="store_true", help="Collect diagnostic counters (threshold_bin, num_above, num_equal, " "remaining_k, refine_rounds, stage2_input) for each sglang kernel") + parser.add_argument("--sample-stride", type=int, default=1, + help="Pre-pass sampling stride for mapped modes (1=full, 4=1/4, 8=1/8). " + "Higher values reduce pre-pass overhead at cost of bin quality (default: 1)") + parser.add_argument("--radix-bits", type=int, default=8, + help="Stage 1 radix bits for ori/mode-0 kernel: 4=16 bins, 6=64, 8=256, 9=512, 10=1024 (default: 8). " + "Range: 4-10. Fewer bits = coarser Stage 1 but faster histogram; more bits = finer but slower.") args = parser.parse_args() results = run_benchmark(args) diff --git a/csrc/register.cc b/csrc/register.cc index b968e9c5..b2d12b9b 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -18,13 +18,16 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_noscale") = false, + py::arg("sample_stride") = 1, + py::arg("radix_bits") = 8); m.def("topk_output_sglang_ori", &topk_output_sglang_ori, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages")); + py::arg("max_num_pages"), + py::arg("radix_bits") = 8); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), @@ -34,7 +37,8 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none(), py::arg("mapping_noscale") = false, - py::arg("topk_val") = 0); + py::arg("topk_val") = 0, + py::arg("sample_stride") = 1); m.def("topk_profile_stage1", &topk_profile_stage1, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), diff --git a/csrc/register.h b/csrc/register.h index d4a311ba..784b754b 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -100,7 +100,9 @@ const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +const bool mapping_noscale = false, +const int64_t sample_stride = 1, +const int64_t radix_bits = 8 ); void topk_output_sglang_ori( @@ -113,7 +115,8 @@ const int64_t eff_batch_size, const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_num_pages +const int64_t max_num_pages, +const int64_t radix_bits = 8 ); void topk_profile_histogram( @@ -128,7 +131,8 @@ const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt, const bool mapping_noscale = false, -const int64_t topk_val = 0 +const int64_t topk_val = 0, +const int64_t sample_stride = 1 ); void topk_profile_stage1( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 08930083..773cdeb9 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -30,7 +30,7 @@ enum TopKMappingMode { MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping MAPPING_POWER = 3, // Monotonic power transform MAPPING_LOG = 4, // Log transform - MAPPING_INDEX_CACHE = 5, // Sentinel: reuse previous layer's indices (Python-level skip) + // Mode 5 reserved (previously INDEX_CACHE, removed) MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 1d12c309..46dcdd79 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -1,10 +1,8 @@ /** - * @NOTE: This file is adapted from - * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py - * We: - * 1. adapt from tilelang to pure cuda - * 2. optimize the performance a little - * 3. fix the potential illegal memory access + * Vortex TopK kernel — mirrors topk_slgang_ori.cu structure with additions: + * - bf16 support, flexible radix, mapping/remap modes + * - CSR paged wrapper kernels for vortex integration + * Profiling kernels are in topk_sglang_profile.cu. */ #include #include @@ -89,9 +87,38 @@ __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); } -// Include mapping strategies (must come after convert_to_uint8 definition) +// ---- Vortex additions ---- + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; + +constexpr int COUNTER_THRESHOLD_BIN = 0; +constexpr int COUNTER_NUM_ABOVE = 1; +constexpr int COUNTER_NUM_EQUAL = 2; +constexpr int COUNTER_REMAINING_K = 3; +constexpr int COUNTER_REFINE_ROUNDS = 4; +constexpr int COUNTER_STAGE2_INPUT = 5; +constexpr int NUM_TOPK_COUNTERS = 6; + +template +__device__ __forceinline__ uint16_t convert_to_uintN(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return key >> (16 - BITS); +} + #include "topk_mapping.cuh" + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { // An optimized topk kernel copied from tilelang kernel // We assume length > TopK here, or it will crash @@ -435,38 +462,11 @@ void setup_kernel_smem_once() { TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); } -// ====================================================================== -// Vortex integration: BOS/EOS-aware segmented TopK with index remapping -// ====================================================================== - -template -__device__ __forceinline__ float vortex_to_float(T x); - -template <> -__device__ __forceinline__ float vortex_to_float(float x) { return x; } - -template <> -__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); -} - -constexpr int VORTEX_MAX_TOPK = 4096; - -// Per-segment diagnostic counters written by WriteCounters mode -constexpr int COUNTER_THRESHOLD_BIN = 0; // Stage 1 coarse threshold bin id -constexpr int COUNTER_NUM_ABOVE = 1; // elements routed above threshold in Stage 1 -constexpr int COUNTER_NUM_EQUAL = 2; // elements in threshold bin (Stage 2 input) -constexpr int COUNTER_REMAINING_K = 3; // topk slots remaining after Stage 1 routing -constexpr int COUNTER_REFINE_ROUNDS = 4; // Stage 2 rounds used (0 = resolved in Stage 1) -constexpr int COUNTER_STAGE2_INPUT = 5; // candidates entering first Stage 2 refine round -constexpr int NUM_TOPK_COUNTERS = 6; - // ====================================================================== // Ori fast path: zero-overhead topk with no mapping infrastructure. -// Adapted from topk_slgang_ori.cu — uses direct convert_to_uint8() -// for Stage 1 binning with no pre-pass, no LUT, no bin cache. +// Template on RADIX_BITS: 4-10 (16 to 1024 bins). // ====================================================================== -template +template __device__ void fast_topk_ori( const ScoreT* __restrict__ input, int* __restrict__ index, @@ -476,10 +476,13 @@ __device__ void fast_topk_ori( { int topk = target_k; constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; + constexpr auto RADIX = 1 << RADIX_BITS; + constexpr auto RADIX_PAD = RADIX / 2; constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); + static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; alignas(128) __shared__ int s_counter; alignas(128) __shared__ int s_threshold_bin_id; alignas(128) __shared__ int s_num_input[2]; @@ -489,20 +492,18 @@ __device__ void fast_topk_ori( const int tx = threadIdx.x; - // Stage 1: 8-bit coarse histogram (direct convert_to_uint8, no mapping) + // Stage 1: coarse histogram with RADIX bins if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); ::atomicAdd(&s_histogram[bin], 1); } __syncthreads(); const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); + for (int i = 0; i < RADIX_BITS; ++i) { if (C10_LIKELY(tx < RADIX)) { const auto j = 1 << i; const auto k = i & 1; @@ -515,6 +516,21 @@ __device__ void fast_topk_ori( __syncthreads(); } }; + // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) + const auto run_cumsum_s2 = [&] { + for (int i = 0; i < 8; ++i) { + if (C10_LIKELY(tx < 256)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < 256 - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; run_cumsum(); if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { @@ -529,7 +545,7 @@ __device__ void fast_topk_ori( if (topk == 0) { for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&s_counter, 1); index[pos] = idx; @@ -539,14 +555,12 @@ __device__ void fast_topk_ori( return; } else { __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } + if (tx < 257) s_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uint8(raw_input)); + const auto bin = static_cast(convert_to_uintN(raw_input)); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&s_counter, 1); index[pos] = idx; @@ -572,8 +586,8 @@ __device__ void fast_topk_ori( const auto _raw_num_input = s_num_input[r_idx]; const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + run_cumsum_s2(); + if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { s_threshold_bin_id = tx; s_num_input[r_idx ^ 1] = 0; s_last_remain = topk - s_histogram[tx + 1]; @@ -597,9 +611,7 @@ __device__ void fast_topk_ori( break; } else { __syncthreads(); - if (tx < RADIX + 1) { - s_histogram[tx] = 0; - } + if (tx < 257) s_histogram[tx] = 0; __syncthreads(); for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = s_input_idx[r_idx][i]; @@ -636,7 +648,7 @@ __device__ void fast_topk_ori( // - ScoreT: float or __nv_bfloat16 // - StopAfterStage1: return after Stage 1 route/filter (for profiling) // - WriteCounters: write diagnostic counters to global memory -// - target_k: runtime parameter (replaces compile-time TopK) + // - mapping: configurable value-remapping for Stage 1 bin assignment template __device__ void fast_topk_vortex( @@ -850,40 +862,52 @@ __device__ void fast_topk_vortex( } __syncthreads(); } else if (needs_topk_window(mapping.mode)) { - // Lightweight topk-window pre-pass: compute min/max of raw values, - // then focus all 256 bins on [tau_low, max] where - // tau_low = max - (max - min) * rho * k / length. - // Like mode 12 but uses a simple heuristic instead of quantile estimation. - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + // Topk-window pre-pass with streaming variance heuristic. + // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; for (int idx = tx; idx < length; idx += BLOCK_SIZE) { float val = vortex_to_float(input[idx + row_start]); - local_min = fminf(local_min, val); local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; } for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); } - __shared__ float s_warp_mins_tw2[32], s_warp_maxs_tw2[32]; + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; { int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins_tw2[warp_id] = local_min; s_warp_maxs_tw2[warp_id] = local_max; } + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } } __syncthreads(); if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins_tw2[tx]; local_max = s_warp_maxs_tw2[tx]; + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); } if (tx == 0) { float rho = mapping.power_exp; if (rho <= 0.0f) rho = 4.0f; int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float frac = rho * float(k) / float(length); - frac = fminf(frac, 1.0f); - float tau_low = local_max - (local_max - local_min) * frac; - if (tau_low >= local_max) tau_low = local_min; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; float range = local_max - tau_low; s_range_min = tau_low; s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; @@ -1130,8 +1154,8 @@ void TopKOutput_Kernel( } } -// Ori fast-path wrapper: zero mapping overhead -template +// Ori fast-path wrapper: zero mapping overhead, flexible radix +template __global__ __launch_bounds__(kThreadsPerBlock) void TopKOutput_Ori_Kernel( const ScoreT* __restrict__ score, @@ -1157,7 +1181,7 @@ void TopKOutput_Ori_Kernel( + page_reserved_bos; __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); __syncthreads(); const int tx = threadIdx.x; @@ -1166,300 +1190,29 @@ void TopKOutput_Ori_Kernel( } } -// ====================================================================== -// Profiling Stage1 kernel: runs pre-pass + hist + cumsum + route/filter, -// stops before Stage 2 refinement (for sub-phase timing) -// ====================================================================== +// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKStage1_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) +void launch_ori_kernel( + const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, + const int* dense_kv_indices, int* sparse_kv_indices, + int topk_val, int reserved_bos, int reserved_eos, + int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } -} - -// ====================================================================== -// Profiling counters kernel: runs full pipeline + writes diagnostic -// counters to a separate global-memory tensor -// ====================================================================== -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKCounters_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] - const int topk_val, - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) -{ - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping, - counters + bx * NUM_TOPK_COUNTERS); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } -} - -// ====================================================================== -// Profiling histogram kernel: runs only Stage 1 and returns per-segment -// 256-bin histograms for distribution analysis -// ====================================================================== -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKHistogram_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - int* __restrict__ histograms, // [eff_batch_size, 256] - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) -{ - constexpr auto RADIX = 256; - constexpr auto BLOCK_SIZE = kThreadsPerBlock; - __shared__ int s_histogram[RADIX]; - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; - __shared__ float s_range_min, s_range_inv_range; - - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - - const ScoreT* __restrict__ score_blk = score + start; - - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } - - // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean for MAPPING_SUBTRACT - float local_sum = 0.0f; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(score_blk[idx]); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums_h[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(nblk); - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass (histogram kernel variant) - constexpr int MAX_SAMPLES_H = 1024; - __shared__ float s_samples_h[MAX_SAMPLES_H]; - __shared__ int s_sample_count_h; - - if (tx == 0) s_sample_count_h = 0; - __syncthreads(); - - const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; - const int sample_stride_h = max(desired_stride, 1); - - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { - float val = vortex_to_float(score_blk[idx]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count_h, 1); - if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; - } - - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_h[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_h[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_h[0]; - - int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); - - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - } - } - - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = mapping.target_k; - float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; - frac = fmaxf(frac, 0.0f); - - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; - } - - if (tau_low >= local_max) { - tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; - } - - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); - } - - // Initialize shared histogram - if (tx < RADIX) s_histogram[tx] = 0; - __syncthreads(); - - // Build histogram over the segment with mapping - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(score_blk[idx]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - // Write to global memory - int* __restrict__ out = histograms + bx * RADIX; - if (tx < RADIX) { - out[tx] = s_histogram[tx]; + #define LAUNCH_ORI(BITS) \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Ori_Kernel<<>>( \ + score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ + topk_val, reserved_bos, reserved_eos) + switch (radix_bits) { + case 4: LAUNCH_ORI(4); break; + case 5: LAUNCH_ORI(5); break; + case 6: LAUNCH_ORI(6); break; + case 7: LAUNCH_ORI(7); break; + case 9: LAUNCH_ORI(9); break; + case 10: LAUNCH_ORI(10); break; + default: LAUNCH_ORI(8); break; } + #undef LAUNCH_ORI } } // namespace @@ -1595,11 +1348,15 @@ void topk_output_sglang( const double mapping_power, std::optional mapping_lut, std::optional mapping_quantiles, - const bool mapping_noscale) + const bool mapping_noscale, + const int64_t sample_stride, + const int64_t radix_bits) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, "topk_output: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output: radix_bits must be 4-10, got ", radix_bits); // Build mapping params from optional tensors TopKMappingParams mapping{}; @@ -1608,7 +1365,7 @@ void topk_output_sglang( mapping.lut = nullptr; mapping.quantiles = nullptr; mapping.noscale = mapping_noscale; - mapping.sample_stride = 1; + mapping.sample_stride = static_cast(sample_stride); mapping.target_k = static_cast(topk_val); if (mapping_lut.has_value()) { @@ -1633,27 +1390,19 @@ void topk_output_sglang( // Fast path for mode 0 (MAPPING_NONE): use ori kernel with zero mapping overhead if (mapping_mode == MAPPING_NONE) { if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( + launch_ori_kernel<__nv_bfloat16>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<<>>( + launch_ori_kernel( x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else { TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); } @@ -1705,11 +1454,14 @@ void topk_output_sglang_ori( const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, - const int64_t max_num_pages) + const int64_t max_num_pages, + const int64_t radix_bits) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, "topk_output_sglang_ori: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); @@ -1722,27 +1474,19 @@ void topk_output_sglang_ori( cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<__nv_bfloat16><<>>( + launch_ori_kernel<__nv_bfloat16>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Ori_Kernel<<>>( + launch_ori_kernel( x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos); + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); } else { TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); } @@ -1751,262 +1495,3 @@ void topk_output_sglang_ori( TORCH_CHECK(result == cudaSuccess, "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); } - -// ====================================================================== -// Profiling: collect per-segment 256-bin histograms of Stage 1 bins -// ====================================================================== -void topk_profile_histogram( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - at::Tensor& histograms, - const int64_t eff_batch_size, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale, - const int64_t topk_val) -{ - CHECK_CUDA(x); - CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(histograms); - TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size - && histograms.size(1) == 256, - "histograms must be [eff_batch_size, 256]"); - TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, - "histograms must be int32"); - - // Build mapping params - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = 1; - mapping.target_k = static_cast(topk_val); - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - TopKHistogram_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - TopKHistogram_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_histogram: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); -} - -// Helper: build TopKMappingParams from host arguments -static TopKMappingParams build_mapping_params( - int64_t mapping_mode, double mapping_power, - std::optional& mapping_lut, - std::optional& mapping_quantiles, - bool mapping_noscale = false, - int sample_stride = 1, - int target_k = 0) -{ - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = sample_stride; - mapping.target_k = target_k; - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - return mapping; -} - -// ====================================================================== -// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) -// ====================================================================== -void topk_profile_stage1( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) -{ - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_stage1: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_stage1: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); -} - -// ====================================================================== -// Profiling: full pipeline + diagnostic counters -// ====================================================================== -void topk_profile_counters( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - at::Tensor& counters, - const int64_t eff_batch_size, - const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) -{ - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_counters: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - CHECK_CUDA(counters); - TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size - && counters.size(1) == NUM_TOPK_COUNTERS, - "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); - TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, - "counters must be int32"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_counters: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); -} - diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu new file mode 100644 index 00000000..6aeac4b8 --- /dev/null +++ b/csrc/topk_sglang_profile.cu @@ -0,0 +1,1203 @@ +/** + * TopK profiling kernels: histogram collection, stage-1-only timing, + * and diagnostic counter collection. + * + * Separated from topk_sglang.cu to reduce template instantiation + * pressure on CUDA shared memory resources. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace { + + +constexpr int TopK = 2048; +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +// On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a +// per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; // bytes +#endif +#else +// Reduced from 128KB to 32KB to improve occupancy. +// Each radix pass needs at most ~TopK candidates in the threshold bin, +// so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) +#endif + +struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; +}; + +// when length <= TopK, we can directly write the indices +__device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } +} + +// keep the first `length` entries, set others to -1 +__device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } +} + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +constexpr int VORTEX_MAX_TOPK = 2048; +constexpr int COUNTER_THRESHOLD_BIN = 0; +constexpr int COUNTER_NUM_ABOVE = 1; +constexpr int COUNTER_NUM_EQUAL = 2; +constexpr int COUNTER_REMAINING_K = 3; +constexpr int COUNTER_REFINE_ROUNDS = 4; +constexpr int COUNTER_STAGE2_INPUT = 5; +constexpr int NUM_TOPK_COUNTERS = 6; + +#include "topk_mapping.cuh" + +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping, + int* counters = nullptr) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. + // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. + float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. + constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance heuristic. + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. + constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKStage1_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling counters kernel: runs full pipeline + writes diagnostic +// counters to a separate global-memory tensor +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKCounters_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// ====================================================================== +// Profiling histogram kernel: runs only Stage 1 and returns per-segment +// 256-bin histograms for distribution analysis +// ====================================================================== +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKHistogram_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + int* __restrict__ histograms, // [eff_batch_size, 256] + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + __shared__ float s_range_min, s_range_inv_range; + + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + + const ScoreT* __restrict__ score_blk = score + start; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean for MAPPING_SUBTRACT + float local_sum = 0.0f; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(score_blk[idx]); + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums_h[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(nblk); + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass (histogram kernel variant) + constexpr int MAX_SAMPLES_H = 1024; + __shared__ float s_samples_h[MAX_SAMPLES_H]; + __shared__ int s_sample_count_h; + + if (tx == 0) s_sample_count_h = 0; + __syncthreads(); + + const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; + const int sample_stride_h = max(desired_stride, 1); + + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { + float val = vortex_to_float(score_blk[idx]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count_h, 1); + if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; + } + + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_h[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_h[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_h[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_h[0]; + + int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); + + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples_h[i] > s_samples_h[i + 1]) { + float tmp = s_samples_h[i]; + s_samples_h[i] = s_samples_h[i + 1]; + s_samples_h[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = mapping.target_k; + float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; + frac = fmaxf(frac, 0.0f); + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; + } + + if (tau_low >= local_max) { + tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance (histogram kernel variant) + float local_max_h = -__FLT_MAX__; + float local_sum_h = 0.0f, local_sum_sq_h = 0.0f; + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + float val = vortex_to_float(score_blk[idx]); + local_max_h = fmaxf(local_max_h, val); + local_sum_h += val; + local_sum_sq_h += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); + local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); + local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); + } + __shared__ float s_warp_maxs_tw3[32], s_warp_sums_tw3[32], s_warp_sq_tw3[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw3[warp_id] = local_max_h; + s_warp_sums_tw3[warp_id] = local_sum_h; + s_warp_sq_tw3[warp_id] = local_sum_sq_h; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max_h = s_warp_maxs_tw3[tx]; + local_sum_h = s_warp_sums_tw3[tx]; + local_sum_sq_h = s_warp_sq_tw3[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); + local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); + local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = mapping.target_k; + float n = float(nblk); + float mean = local_sum_h / n; + float var = local_sum_sq_h / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max_h - rho * sigma * z; + if (tau_low >= local_max_h) tau_low = local_max_h - 1.0f; + float range = local_max_h - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Initialize shared histogram + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + + // Build histogram over the segment with mapping + for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(score_blk[idx]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + // Write to global memory + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) { + out[tx] = s_histogram[tx]; + } +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// ====================================================================== +// Profiling: collect per-segment 256-bin histograms of Stage 1 bins +// ====================================================================== +void topk_profile_histogram( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + at::Tensor& histograms, + const int64_t eff_batch_size, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale, + const int64_t topk_val, + const int64_t sample_stride) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + // Build mapping params + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = static_cast(sample_stride); + mapping.target_k = static_cast(topk_val); + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_histogram: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); +} + +// Helper: build TopKMappingParams from host arguments +static TopKMappingParams build_mapping_params( + int64_t mapping_mode, double mapping_power, + std::optional& mapping_lut, + std::optional& mapping_quantiles, + bool mapping_noscale = false, + int sample_stride = 1, + int target_k = 0) +{ + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + mapping.noscale = mapping_noscale; + mapping.sample_stride = sample_stride; + mapping.target_k = target_k; + + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + CHECK_CUDA(lut); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + mapping.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + CHECK_CUDA(q); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + mapping.quantiles = q.data_ptr(); + } + return mapping; +} + +// ====================================================================== +// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) +// ====================================================================== +void topk_profile_stage1( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_stage1: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKStage1_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_stage1: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling: full pipeline + diagnostic counters +// ====================================================================== +void topk_profile_counters( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& counters, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles, + const bool mapping_noscale) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, + "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, + mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, + reserved_bos, + reserved_eos, + mapping); + } else { + TORCH_CHECK(false, + "topk_profile_counters: unsupported dtype ", + x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); +} + diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 6806eca5..fcc2ff1f 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -47,6 +47,9 @@ MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" +RADIX_BITS=8 +SAMPLE_STRIDE=1 +SEQ_LEN=32768 # The path to the raw_histograms.npy file (set to skip calibration) REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" REAL_HISTOGRAMS="" @@ -59,12 +62,23 @@ while [[ $# -gt 0 ]]; do --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --radix-bits) RADIX_BITS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" +# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -77,6 +91,8 @@ echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" echo " GPU: ${GPU_ID}" +echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" +echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" echo "============================================================" @@ -117,7 +133,7 @@ fi PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 32768 \ + --seq-len ${SEQ_LEN} \ --num-kv-heads 2 \ "${AUTOTUNE_EXTRA_ARGS[@]}" \ --output-json "${AUTOTUNE_JSON}" \ @@ -157,7 +173,7 @@ fi PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 32768 \ + --seq-lens ${SEQ_LEN} \ --topk-vals "${TOPK_VAL}" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ @@ -166,6 +182,8 @@ PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ "${BENCH_EXTRA_ARGS[@]}" \ --autotune-json "${AUTOTUNE_JSON}" \ --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ + --radix-bits "${RADIX_BITS}" \ + --sample-stride "${SAMPLE_STRIDE}" \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 1f89c0b9..65e4f413 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -31,9 +31,12 @@ BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" -TOPK_VAL=30 +TOPK_VAL=2048 MEM=0.7 ALGO="block_sparse_attention" +RADIX_BITS=8 +SAMPLE_STRIDE=1 +SEQ_LEN=65536 # The path to the raw_histograms.npy file (set to skip calibration) REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" # REAL_HISTOGRAMS="" @@ -46,12 +49,23 @@ while [[ $# -gt 0 ]]; do --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --radix-bits) RADIX_BITS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" +# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) @@ -63,7 +77,10 @@ echo "Bucket Distribution Profiling (modes 3, 6, 7, 8, 9, 10, 11)" echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / 16 )) pages/seg)" echo " GPU: ${GPU_ID}" +echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" +echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" echo "============================================================" @@ -98,7 +115,7 @@ AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ --batch-size 4 \ - --seq-len 32768 \ + --seq-len ${SEQ_LEN} \ --num-kv-heads 8 \ --real-histograms "${REAL_HIST_PATH}" \ --latency-rerank \ @@ -115,7 +132,7 @@ BENCH_JSON="${RUN_DIR}/bench_distribution.json" PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --batch-sizes 4 \ - --seq-lens 32768 \ + --seq-lens ${SEQ_LEN} \ --topk-vals "${TOPK_VAL}" \ --num-kv-heads 8 \ --distributions bucket_uniform normal \ @@ -124,6 +141,8 @@ PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --real-histograms "${REAL_HIST_PATH}" \ --autotune-json "${AUTOTUNE_JSON}" \ --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ + --radix-bits "${RADIX_BITS}" \ + --sample-stride "${SAMPLE_STRIDE}" \ --repeat 20 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index 32ff5a39..a1d1b6f3 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -119,10 +119,9 @@ def verify_algos( kv_cache_dtype: str = "auto", topk_type: str = "naive", topk_mapping_mode: int = 0, -topk_mapping_power: float = 0.5, +topk_mapping_hparam: float = 0.5, topk_mapping_lut_path: str = None, topk_mapping_quantiles_path: str = None, -index_cache_shared_layers: list = None, disable_cuda_graph: bool = False, benchmark: str = "amc23", ): @@ -143,10 +142,9 @@ def verify_algos( kv_cache_dtype=kv_cache_dtype, vortex_topk_type=topk_type, vortex_topk_mapping_mode=topk_mapping_mode, - vortex_topk_mapping_power=topk_mapping_power, + vortex_topk_mapping_hparam=topk_mapping_hparam, vortex_topk_mapping_lut_path=topk_mapping_lut_path, vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, - vortex_index_cache_shared_layers=index_cache_shared_layers, ) tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) @@ -310,14 +308,15 @@ def parse_args(): type=int, default=0, choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 5=index_cache, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', + help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', ) parser.add_argument( - "--topk-mapping-power", + "--topk-mapping-hparam", "--topk-mapping-power", type=float, default=0.5, - help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6 asinh), alpha (mode 7 log1p), rho tail expansion (mode 12). Default: 0.5.', + dest="topk_mapping_hparam", + help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6), alpha (mode 7/9/10/13), rho (mode 12/14). Default: 0.5.', ) parser.add_argument( @@ -334,14 +333,6 @@ def parse_args(): help="Path to .npy file with float32[256] quantiles for topk mapping mode 2.", ) - parser.add_argument( - "--index-cache-shared-layers", - type=int, - nargs="+", - default=None, - help="Layer IDs that reuse indices from the nearest preceding full layer (skip indexer).", - ) - parser.add_argument( "--benchmark", type=str, @@ -356,12 +347,6 @@ def parse_args(): if __name__ == "__main__": args = parse_args() - # --- Mode 5: Index Cache (default even-layer pattern) --- - if args.topk_mapping_mode == 5: - if args.index_cache_shared_layers is None: - args.index_cache_shared_layers = list(range(2, 28, 2)) # [2,4,6,...,26] - args.topk_mapping_mode = 0 - for bench_name in args.benchmark: if bench_name not in BENCHMARK_REGISTRY: print(f"WARNING: Unknown benchmark '{bench_name}', skipping. Available: {list(BENCHMARK_REGISTRY.keys())}") @@ -380,10 +365,9 @@ def parse_args(): kv_cache_dtype=args.kv_cache_dtype, topk_type=args.topk_type, topk_mapping_mode=args.topk_mapping_mode, - topk_mapping_power=args.topk_mapping_power, + topk_mapping_hparam=args.topk_mapping_hparam, topk_mapping_lut_path=args.topk_mapping_lut_path, topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, - index_cache_shared_layers=args.index_cache_shared_layers, benchmark=bench_name, ) summary["benchmark"] = bench_name diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index c0a03c54..9a9f482e 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -112,17 +112,17 @@ for r in data: if m not in best or r['gini'] < best[m]['gini']: best[m] = r for m in (3, 6, 7, 9, 10): - print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') + print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') " "${AUTOTUNE_JSON}")" - echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10}" + echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10}" echo "" else echo ">>> WARNING: ${REAL_HISTOGRAMS} not found, using default power=0.5 for all modes" - BEST_POWER_3=0.5 - BEST_POWER_6=0.5 - BEST_POWER_7=0.5 - BEST_POWER_9=0.5 - BEST_POWER_10=0.5 + BEST_HPARAM_3=0.5 + BEST_HPARAM_6=0.5 + BEST_HPARAM_7=0.5 + BEST_HPARAM_9=0.5 + BEST_HPARAM_10=0.5 fi # ============================================================ @@ -189,8 +189,8 @@ done # Mode 3: power — autotuned best p # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" - echo ">>> Running mode 3 (power) p=${BEST_POWER_3} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" + echo ">>> Running mode 3 (power) p=${BEST_HPARAM_3} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -199,7 +199,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 3 \ - --topk-mapping-power ${BEST_POWER_3} \ + --topk-mapping-hparam ${BEST_HPARAM_3} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -208,8 +208,8 @@ done # Mode 6: asinh — autotuned best beta # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" - echo ">>> Running mode 6 (asinh) beta=${BEST_POWER_6} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" + echo ">>> Running mode 6 (asinh) beta=${BEST_HPARAM_6} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -218,7 +218,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 6 \ - --topk-mapping-power ${BEST_POWER_6} \ + --topk-mapping-hparam ${BEST_HPARAM_6} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -227,8 +227,8 @@ done # Mode 7: log1p — autotuned best alpha # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" - echo ">>> Running mode 7 (log1p) alpha=${BEST_POWER_7} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" + echo ">>> Running mode 7 (log1p) alpha=${BEST_HPARAM_7} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -237,7 +237,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 7 \ - --topk-mapping-power ${BEST_POWER_7} \ + --topk-mapping-hparam ${BEST_HPARAM_7} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -246,8 +246,8 @@ done # Mode 9: erf — autotuned best alpha # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" - echo ">>> Running mode 9 (erf) alpha=${BEST_POWER_9} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" + echo ">>> Running mode 9 (erf) alpha=${BEST_HPARAM_9} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -256,7 +256,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 9 \ - --topk-mapping-power ${BEST_POWER_9} \ + --topk-mapping-hparam ${BEST_HPARAM_9} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done @@ -265,8 +265,8 @@ done # Mode 10: tanh — autotuned best alpha # ============================================================ for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" - echo ">>> Running mode 10 (tanh) alpha=${BEST_POWER_10} (autotuned) for ${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" + echo ">>> Running mode 10 (tanh) alpha=${BEST_HPARAM_10} (autotuned) for ${algo}" echo ">>> Saving results to ${OUTFILE}" { time python verify_algo.py \ --trials 8 \ @@ -275,7 +275,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 10 \ - --topk-mapping-power ${BEST_POWER_10} \ + --topk-mapping-hparam ${BEST_HPARAM_10} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 4c96a15b..6848e1ea 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -78,9 +78,9 @@ for r in data: if m not in best or r['gini'] < best[m]['gini']: best[m] = r for m in (3, 6, 7, 9, 10, 13, 14): - print(f'BEST_POWER_{m}={best[m][\"param\"]}' if m in best else f'BEST_POWER_{m}=0.5') + print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') " "${AUTOTUNE_JSON}")" -echo ">>> Autotuned best powers: mode3=${BEST_POWER_3} mode6=${BEST_POWER_6} mode7=${BEST_POWER_7} mode9=${BEST_POWER_9} mode10=${BEST_POWER_10} mode13=${BEST_POWER_13} mode14=${BEST_POWER_14}" +echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode13=${BEST_HPARAM_13} mode14=${BEST_HPARAM_14}" echo "" # ============================================================ @@ -107,11 +107,11 @@ done # Step 1: Mode 3 (power) — autotuned best p # ============================================================ echo "============================================================" -echo "Step 1: Mode 3 (power) — p=${BEST_POWER_3} (autotuned)" +echo "Step 1: Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_POWER_3}_${TIMESTAMP}.log" - echo ">>> Mode 3 (power) p=${BEST_POWER_3} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" + echo ">>> Mode 3 (power) p=${BEST_HPARAM_3} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -119,7 +119,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 3 \ - --topk-mapping-power ${BEST_POWER_3} \ + --topk-mapping-hparam ${BEST_HPARAM_3} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -129,11 +129,11 @@ done # Step 2: Mode 6 (asinh) — autotuned best beta # ============================================================ echo "============================================================" -echo "Step 2: Mode 6 (asinh) — beta=${BEST_POWER_6} (autotuned)" +echo "Step 2: Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_POWER_6}_${TIMESTAMP}.log" - echo ">>> Mode 6 (asinh) beta=${BEST_POWER_6} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" + echo ">>> Mode 6 (asinh) beta=${BEST_HPARAM_6} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -141,7 +141,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 6 \ - --topk-mapping-power ${BEST_POWER_6} \ + --topk-mapping-hparam ${BEST_HPARAM_6} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -151,11 +151,11 @@ done # Step 3: Mode 7 (log1p) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 3: Mode 7 (log1p) — alpha=${BEST_POWER_7} (autotuned)" +echo "Step 3: Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_POWER_7}_${TIMESTAMP}.log" - echo ">>> Mode 7 (log1p) alpha=${BEST_POWER_7} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" + echo ">>> Mode 7 (log1p) alpha=${BEST_HPARAM_7} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -163,7 +163,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 7 \ - --topk-mapping-power ${BEST_POWER_7} \ + --topk-mapping-hparam ${BEST_HPARAM_7} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -194,11 +194,11 @@ done # Step 5: Mode 9 (erf) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 5: Mode 9 (erf) — alpha=${BEST_POWER_9} (autotuned)" +echo "Step 5: Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_POWER_9}_${TIMESTAMP}.log" - echo ">>> Mode 9 (erf) alpha=${BEST_POWER_9} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" + echo ">>> Mode 9 (erf) alpha=${BEST_HPARAM_9} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -206,7 +206,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 9 \ - --topk-mapping-power ${BEST_POWER_9} \ + --topk-mapping-hparam ${BEST_HPARAM_9} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -216,11 +216,11 @@ done # Step 6: Mode 10 (tanh) — autotuned best alpha # ============================================================ echo "============================================================" -echo "Step 6: Mode 10 (tanh) — alpha=${BEST_POWER_10} (autotuned)" +echo "Step 6: Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_POWER_10}_${TIMESTAMP}.log" - echo ">>> Mode 10 (tanh) alpha=${BEST_POWER_10} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" + echo ">>> Mode 10 (tanh) alpha=${BEST_HPARAM_10} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -228,7 +228,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 10 \ - --topk-mapping-power ${BEST_POWER_10} \ + --topk-mapping-hparam ${BEST_HPARAM_10} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -272,7 +272,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 12 \ - --topk-mapping-power 4.0 \ + --topk-mapping-hparam 4.0 \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -283,11 +283,11 @@ done # ============================================================ echo "" echo "============================================================" -echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_POWER_13} (autotuned)" +echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_POWER_13}_${TIMESTAMP}.log" - echo ">>> Mode 13 (exp_stretch) alpha=${BEST_POWER_13} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_HPARAM_13}_${TIMESTAMP}.log" + echo ">>> Mode 13 (exp_stretch) alpha=${BEST_HPARAM_13} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -295,7 +295,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 13 \ - --topk-mapping-power ${BEST_POWER_13} \ + --topk-mapping-hparam ${BEST_HPARAM_13} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -306,11 +306,11 @@ done # ============================================================ echo "" echo "============================================================" -echo "Step 10: Mode 14 (topk_window) — rho=${BEST_POWER_14} (autotuned)" +echo "Step 10: Mode 14 (topk_window) — rho=${BEST_HPARAM_14} (autotuned)" echo "============================================================" for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_POWER_14}_${TIMESTAMP}.log" - echo ">>> Mode 14 (topk_window) rho=${BEST_POWER_14} algo=${algo}" + OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_HPARAM_14}_${TIMESTAMP}.log" + echo ">>> Mode 14 (topk_window) rho=${BEST_HPARAM_14} algo=${algo}" { time python verify_algo.py \ --trials 8 \ --topk-val ${TOPK_VAL} \ @@ -318,7 +318,7 @@ for algo in "${sparse_algos[@]}"; do --model-name Qwen/Qwen3-1.7B \ --topk-type sglang \ --topk-mapping-mode 14 \ - --topk-mapping-power ${BEST_POWER_14} \ + --topk-mapping-hparam ${BEST_HPARAM_14} \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" @@ -343,7 +343,7 @@ PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --real-histograms "${REAL_HISTOGRAMS}" \ --autotune-json "${AUTOTUNE_JSON}" \ --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 sglang_m13 sglang_m14 \ - --mapping-power-13 ${BEST_POWER_13} --mapping-power-14 ${BEST_POWER_14} \ + --mapping-hparam-13 ${BEST_HPARAM_13} --mapping-hparam-14 ${BEST_HPARAM_14} \ --repeat 5 \ --output-json "${COUNTER_JSON}" \ 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" @@ -357,14 +357,14 @@ echo "============================================================" echo "All runs complete. Results in ${RESULTS_DIR}/" echo " Auto-tune: ${AUTOTUNE_JSON}" echo " Counters: ${COUNTER_JSON}" -echo " Mode 3 (power): p = ${BEST_POWER_3} (autotuned)" -echo " Mode 6 (asinh): beta = ${BEST_POWER_6} (autotuned)" -echo " Mode 7 (log1p): alpha = ${BEST_POWER_7} (autotuned)" +echo " Mode 3 (power): p = ${BEST_HPARAM_3} (autotuned)" +echo " Mode 6 (asinh): beta = ${BEST_HPARAM_6} (autotuned)" +echo " Mode 7 (log1p): alpha = ${BEST_HPARAM_7} (autotuned)" echo " Mode 8 (trunc8): (fixed)" -echo " Mode 9 (erf): alpha = ${BEST_POWER_9} (autotuned)" -echo " Mode 10 (tanh): alpha = ${BEST_POWER_10} (autotuned)" +echo " Mode 9 (erf): alpha = ${BEST_HPARAM_9} (autotuned)" +echo " Mode 10 (tanh): alpha = ${BEST_HPARAM_10} (autotuned)" echo " Mode 11 (subtract): (fixed)" echo " Mode 12 (tail_win): rho = 4.0" -echo " Mode 13 (exp_stretch):alpha = ${BEST_POWER_13} (autotuned)" -echo " Mode 14 (topk_window):rho = ${BEST_POWER_14} (autotuned)" +echo " Mode 13 (exp_stretch):alpha = ${BEST_HPARAM_13} (autotuned)" +echo " Mode 14 (topk_window):rho = ${BEST_HPARAM_14} (autotuned)" echo "============================================================" diff --git a/setup.py b/setup.py index 9c2186b9..0fc46ad8 100644 --- a/setup.py +++ b/setup.py @@ -18,6 +18,7 @@ 'csrc/utils_sglang.cu', 'csrc/topk.cu', 'csrc/topk_sglang.cu', + 'csrc/topk_sglang_profile.cu', ], include_dirs=['csrc'], extra_compile_args={ diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index b50ca74f..e4424cdf 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -247,7 +247,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso if self.topk_type == "sglang": # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) - mapping_power = getattr(ctx, 'topk_mapping_power', 0.5) + mapping_hparam = getattr(ctx, 'topk_mapping_hparam', getattr(ctx, 'topk_mapping_power', 0.5)) mapping_lut = getattr(ctx, 'topk_mapping_lut', None) mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) mapping_noscale = getattr(ctx, 'topk_mapping_noscale', False) @@ -268,7 +268,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_eos, ctx.max_num_pages_per_request, mapping_mode, - mapping_power, + mapping_hparam, mapping_lut, mapping_quantiles, mapping_noscale, @@ -320,7 +320,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, mapping_mode, - mapping_power, + mapping_hparam, mapping_lut, mapping_quantiles, mapping_noscale, From 524834a5f5448e9595d746e95966e2190d4b9622 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 13 Apr 2026 03:09:46 -0400 Subject: [PATCH 18/24] Refactor TopK mapping and benchmarking scripts for enhanced profiling and usability - Updated autotune_topk_mapping.py to optimize hyperparameter tuning based on kernel latency. - Simplified the sweep grid and improved documentation for usage. - Enhanced bench_topk.py to expose public helpers and added CLI modes for benchmarking. - Introduced new remap functions and improved kernel integration for profiling. - Added watchdog timeout option in calibrate_topk.py for SGLang scheduler. - Removed outdated greedy_layer_search.py as part of code cleanup. --- benchmarks/autotune_topk_mapping.py | 797 ++++------ benchmarks/bench_topk.py | 1139 +++++--------- benchmarks/calibrate_topk.py | 13 +- benchmarks/greedy_layer_search.py | 117 -- csrc/archived/README.md | 19 + csrc/archived/fast_topk_vortex_prepass.cu | 525 +++++++ csrc/archived/topk_mapping_full.cuh | 217 +++ csrc/archived/topk_sglang_ori_fastpath.cu | 319 ++++ csrc/{ => archived}/topk_slgang_ori.cu | 0 csrc/register.cc | 42 +- csrc/register.h | 43 +- csrc/topk_mapping.cuh | 208 +-- csrc/topk_sglang.cu | 1231 ++++++--------- csrc/topk_sglang_profile.cu | 1329 +++++------------ examples/remap_function_bench.sh | 238 +++ examples/run_distribution_analysis.sh | 254 ++-- examples/run_distribution_analysis_new.sh | 194 +-- examples/test_topk.py | 118 ++ examples/verify_algo.py | 30 +- examples/verify_algo.sh | 7 +- examples/verify_algo_topk_mapping.sh | 396 ++--- examples/verify_algo_topk_mapping_new.sh | 433 ++---- ...ackends.sh => verify_external_backends.sh} | 0 vortex_torch/indexer/context.py | 27 +- vortex_torch/indexer/output_func.py | 51 +- 25 files changed, 3553 insertions(+), 4194 deletions(-) delete mode 100644 benchmarks/greedy_layer_search.py create mode 100644 csrc/archived/README.md create mode 100644 csrc/archived/fast_topk_vortex_prepass.cu create mode 100644 csrc/archived/topk_mapping_full.cuh create mode 100644 csrc/archived/topk_sglang_ori_fastpath.cu rename csrc/{ => archived}/topk_slgang_ori.cu (100%) create mode 100755 examples/remap_function_bench.sh create mode 100644 examples/test_topk.py rename examples/{verify_sparse_backends.sh => verify_external_backends.sh} (100%) diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index 8051c14e..db213213 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -1,254 +1,131 @@ """ -Auto-tuner for TopK mapping hyperparameters. +Auto-tune TopK mapping hyperparameters by profiled kernel latency. -Sweeps all (mode, hyperparameter) combinations using the topk_hit_rate -kernel and ranks by Stage 1 resolution rate. +For each (mode, hyperparameter) combo in the sweep grid, this script runs +the fused remap+topk kernel (topk_output_sglang_fused) on synthetic or +real-distribution inputs, measures end-to-end latency with CUDA events, +and picks the hyperparameter with the lowest measured latency per mode. -Supports real-data score distributions via --real-histograms: loads the -raw_histograms.npy from calibration and synthesizes score tensors that -match the real bin distribution (by reversing the convert_to_uint8 mapping). - -Sweep grid: - - Mode 3 (power): p in [0.1, 0.25, 0.75, 0.9] - - Mode 6 (asinh): beta in [0.1, 0.5, 1, 2, 4] - - Mode 7 (log1p): alpha in [0.1, 0.5, 0.75, 1, 2, 4, 8] - - Baselines: mode 0 (none), mode 4 (log) +Distribution statistics (gini, max/mean, counter-based Stage-2 cost) are +still collected for diagnostics, but they do NOT drive the ranking — the +ranking is purely latency-driven. Usage: - python benchmarks/autotune_topk_mapping.py --topk-val 30 --real-histograms calibration/raw_histograms.npy - python benchmarks/autotune_topk_mapping.py --topk-val 30 --output-json results.json + python benchmarks/autotune_topk_mapping.py \\ + --topk-val 2048 --batch-size 4 --seq-len 65536 --num-kv-heads 8 \\ + --real-histograms calibration/raw_histograms.npy \\ + --output-json autotune_results.json """ import argparse import json import math -from typing import List +from typing import Dict, List, Optional import numpy as np import torch from bench_topk import make_topk_inputs, bench_kernel, compute_histogram_stats -from vortex_torch_C import topk_profile_histogram, topk_profile_counters, topk_output_sglang - - - -SWEEP_GRID = { - # (mode, param_name, param_values) - 3: ("power_exp", [0.1, 0.25, 0.5, 0.75, 0.9, 2.0, 4.0]), - 6: ("beta", [0.1, 0.5, 1.0, 2.0, 4.0]), - 7: ("alpha", [0.1, 0.5, 0.75, 1.0, 2.0, 4.0, 8.0]), - 9: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), - 10: ("alpha", [0.1, 0.5, 1.0, 2.0, 4.0]), - 13: ("alpha", [0.5, 1.0, 2.0, 4.0, 8.0]), - 14: ("rho", [2.0, 4.0, 8.0, 16.0]), -} -BASELINES = { - 0: ("none", 0.5), - 4: ("log", 0.5), - 8: ("trunc8", 0.5), - 11: ("subtract", 0.5), -} -# Noscale baselines for parametric transform modes (skip auto-range pre-pass) -NOSCALE_BASELINES = { - 3: ("power_noscale", [0.5]), - 6: ("asinh_noscale", [1.0]), - 7: ("log1p_noscale", [1.0]), - 9: ("erf_noscale", [1.0]), - 10: ("tanh_noscale", [1.0]), - 13: ("exp_stretch_noscale", [1.0, 4.0]), +from vortex_torch_C import ( + topk_output_sglang_fused, + topk_profile_histogram, + topk_profile_counters, +) + + +# Only parametric modes need auto-tuning. Mode 0 (none) and mode 4 (log) +# have no knob; mode 0 is always the baseline. +SWEEP_GRID: Dict[int, List[float]] = { + 3: [0.1, 0.25, 0.5, 0.75, 0.9], # power: p + 6: [0.1, 0.5, 1.0, 2.0, 4.0], # asinh: beta + 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0], # log1p: alpha + 9: [0.1, 0.5, 1.0, 2.0, 4.0], # erf: alpha + 10: [0.1, 0.5, 1.0, 2.0, 4.0], # tanh: alpha + 11: [-1.0, -0.5, 0.0, 0.5, 1.0], # subtract: pivot (free hparam) + 13: [0.5, 1.0, 2.0, 4.0, 8.0], # exp_stretch: alpha } + +PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 11: "pivot", 13: "alpha"} MODE_NAMES = { - 0: "none", - 3: "power", - 4: "log", - 6: "asinh", - 7: "log1p", - 8: "trunc8", - 9: "erf", - 10: "tanh", - 11: "subtract", - 13: "exp_stretch", - 14: "topk_window", + 0: "none", 1: "lut_cdf", 2: "quantile", + 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", 13: "exp_stretch", } +# Non-parametric modes — no knob to sweep; timed once as a reference point. +# LUT_CDF (1) and QUANTILE (2) are added here at runtime when the caller +# passes --lut-path / --quantiles-path. +BASELINES = [(0, 0.5), (4, 0.5), (8, 0.5)] -def _key_to_fp16(key: int) -> np.float16: - """Invert the convert_to_uint8 sign-flip for a single 16-bit key.""" - if key >= 0x8000: - bits = key & 0x7FFF - else: - bits = (~key) & 0xFFFF - return np.array([bits], dtype=np.uint16).view(np.float16)[0] +# ---------- Real-distribution score generation ---------- -def build_bin_range_table(): - """Build per-bin (lo, hi) fp16 value tables by iterating all 65536 fp16 bit patterns. +def _key_to_fp16(key: int) -> np.float16: + """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" + bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) + return np.array([bits], dtype=np.uint16).view(np.float16)[0] - For each fp16 value, compute its bin via convert_to_uint8 logic, then track - the min/max fp16 value that lands in each bin. - Returns: - (bin_lo, bin_hi): two [256] float32 arrays — the min and max fp16 values per bin. - """ - # Generate all 65536 fp16 bit patterns +def _build_bin_range_table(): + """Return per-bin (lo, hi) fp16 value tables for all 256 radix bins.""" all_bits = np.arange(65536, dtype=np.uint16) all_fp16 = all_bits.view(np.float16) - - # Compute convert_to_uint8 for each: key = sign-flip, bin = key >> 8 keys = np.where( (all_bits & 0x8000).astype(bool), (~all_bits).astype(np.uint16), all_bits | np.uint16(0x8000), ) bins = (keys >> 8).astype(np.uint8) - - # Convert to float32 for min/max (fp16 has NaNs/Infs, filter them) all_f32 = all_fp16.astype(np.float32) valid = np.isfinite(all_f32) - bin_lo = np.full(256, np.inf, dtype=np.float32) bin_hi = np.full(256, -np.inf, dtype=np.float32) - for b in range(256): mask = (bins == b) & valid if mask.any(): vals = all_f32[mask] bin_lo[b] = vals.min() bin_hi[b] = vals.max() - - # For any bin with no valid fp16 values, fall back to midpoint empty = bin_lo > bin_hi for b in np.where(empty)[0]: - mid_key = (int(b) << 8) | 0x80 - val = float(_key_to_fp16(mid_key)) + val = float(_key_to_fp16((int(b) << 8) | 0x80)) bin_lo[b] = val bin_hi[b] = val - return bin_lo, bin_hi -def generate_remap_lut(mode: int, param: float) -> np.ndarray: - """Generate a 256-entry uint8 LUT that approximates a transform mode. - - For each of the 256 fp16 radix bins, compute the transform of the - bin's midpoint value, then linearly map transformed values to [0,255]. - The resulting LUT can be used with mode=1 (LUT CDF) infrastructure, - replacing expensive per-element transcendental math with a single - shared memory lookup. - - Args: - mode: TopKMappingMode (3=Power, 4=Log, 6=Asinh, 7=Log1p, 9=Erf, 10=Tanh) - param: power_exp/beta/alpha for the transform - - Returns: - lut: [256] uint8 array mapping original_bin -> remapped_bin - """ - bin_lo, bin_hi = build_bin_range_table() - midpoints = (bin_lo + bin_hi) / 2.0 # [256] float32 - - # Apply transform - if mode == 3: # power - transformed = np.sign(midpoints) * np.abs(midpoints) ** param - elif mode == 4: # log - transformed = np.sign(midpoints) * np.log(np.abs(midpoints) + 1.0) - elif mode == 6: # asinh - transformed = np.arcsinh(param * midpoints) - elif mode == 7: # log1p - transformed = np.sign(midpoints) * np.log1p(param * np.abs(midpoints)) - elif mode == 9: # erf - from scipy.special import erf - transformed = erf(param * midpoints) - elif mode == 10: # tanh - transformed = np.tanh(param * midpoints) - else: - # Identity fallback - transformed = midpoints.copy() - - # Handle NaN/Inf from edge cases - transformed = np.nan_to_num(transformed, nan=0.0, posinf=0.0, neginf=0.0) - - # Linear map to [0, 255] - tmin, tmax = transformed.min(), transformed.max() - if tmax > tmin: - lut = np.clip(((transformed - tmin) / (tmax - tmin) * 255), 0, 255).astype(np.uint8) - else: - lut = np.full(256, 128, dtype=np.uint8) - - return lut - - -def scores_from_histogram( - histogram: np.ndarray, - total_pages: int, - device: str = "cuda", -) -> torch.Tensor: - """Generate score tensor matching a real bin distribution. - - For each sampled bin, generates a uniform random fp16 value within the - bin's actual value range (not just the midpoint), so that mapped transforms - see diverse input values. - - Args: - histogram: [256] aggregated bin counts from calibration - total_pages: number of score entries to generate - device: torch device - - Returns: - scores: [total_pages, 1, 1] bfloat16 tensor - """ - bin_lo, bin_hi = build_bin_range_table() - - # Normalize histogram to probability distribution +def _scores_from_histogram(histogram: np.ndarray, total_pages: int, device="cuda") -> torch.Tensor: + bin_lo, bin_hi = _build_bin_range_table() counts = histogram.astype(np.float64) total = counts.sum() if total == 0: return torch.zeros(total_pages, 1, 1, dtype=torch.bfloat16, device=device) probs = counts / total - - # Sample bin indices according to the real distribution bin_indices = np.random.choice(256, size=total_pages, p=probs) - - # Uniform random within each bin's fp16 range lo = bin_lo[bin_indices] hi = bin_hi[bin_indices] rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) scores_f32 = lo + rand * (hi - lo) + return torch.from_numpy(scores_f32).to(torch.bfloat16).reshape(total_pages, 1, 1).to(device) + - # Convert float32 -> bfloat16 tensor - scores = torch.from_numpy(scores_f32).to(torch.bfloat16) - return scores.reshape(total_pages, 1, 1).to(device) - - -def make_real_inputs( - batch_size: int, - num_kv_heads: int, - seq_len: int, - page_size: int, - topk_val: int, - reserved_bos: int, - reserved_eos: int, - histogram: np.ndarray, - device: str = "cuda", -) -> dict: - """Build CSR-formatted inputs with scores matching a real histogram.""" - eff_batch_size = batch_size * num_kv_heads - num_pages_per_seg = math.ceil(seq_len / page_size) - total_dense_pages = eff_batch_size * num_pages_per_seg - sparse_per_seg = min(topk_val + reserved_bos + reserved_eos, num_pages_per_seg) - total_sparse_pages = eff_batch_size * sparse_per_seg +def _make_real_inputs(args, histogram: np.ndarray) -> dict: + eff_bs = args.batch_size * args.num_kv_heads + num_pages_per_seg = math.ceil(args.seq_len / args.page_size) + total_dense = eff_bs * num_pages_per_seg + sparse_per_seg = min(args.topk_val + args.reserved_bos + args.reserved_eos, num_pages_per_seg) dense_kv_indptr = torch.arange( - 0, (eff_batch_size + 1) * num_pages_per_seg, num_pages_per_seg, - dtype=torch.int32, device=device, + 0, (eff_bs + 1) * num_pages_per_seg, num_pages_per_seg, + dtype=torch.int32, device="cuda", ) sparse_kv_indptr = torch.arange( - 0, (eff_batch_size + 1) * sparse_per_seg, sparse_per_seg, - dtype=torch.int32, device=device, + 0, (eff_bs + 1) * sparse_per_seg, sparse_per_seg, + dtype=torch.int32, device="cuda", ) - dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) - sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) - - x = scores_from_histogram(histogram, total_dense_pages, device=device) + dense_kv_indices = torch.arange(total_dense, dtype=torch.int32, device="cuda") + sparse_kv_indices = torch.zeros(eff_bs * sparse_per_seg, dtype=torch.int32, device="cuda") + x = _scores_from_histogram(histogram, total_dense) return { "x": x, @@ -256,402 +133,240 @@ def make_real_inputs( "sparse_kv_indptr": sparse_kv_indptr, "dense_kv_indices": dense_kv_indices, "sparse_kv_indices": sparse_kv_indices, - "eff_batch_size": eff_batch_size, + "eff_batch_size": eff_bs, "num_pages_per_seg": num_pages_per_seg, "sparse_per_seg": sparse_per_seg, } -def run_sweep(args) -> List[dict]: - """Run all (mode, hyperparam) combos and return ranked results.""" - results = [] - - # Load real histogram if provided - real_histogram = None - if args.real_histograms: - raw = np.load(args.real_histograms) # [num_segments, 256] - real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw # aggregate to [256] - - distributions = args.distributions - if real_histogram is not None: - distributions = ["real"] - - for dist in distributions: - if dist == "real": - inputs = make_real_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - histogram=real_histogram, - ) - else: - inputs = make_topk_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=torch.bfloat16, - distribution=dist, - ) - - eff_bs = inputs["eff_batch_size"] - - def evaluate(mode: int, power: float, label: str, noscale: bool = False, - lut_tensor=None): - hists = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") - topk_profile_histogram( - inputs["x"], - inputs["dense_kv_indptr"], - hists, - eff_bs, - args.reserved_bos, - args.reserved_eos, - mode, - power, - lut_tensor, # lut - None, # quantiles - noscale, - ) - torch.cuda.synchronize() - stats = compute_histogram_stats(hists) - result = { - "label": label, - "mode": mode, - "mode_name": MODE_NAMES.get(mode, f"m{mode}"), - "param": power, - "noscale": noscale, - "distribution": dist, - "gini": stats["gini"], - "max_mean_ratio": stats["max_mean_ratio"], - "num_nonzero_bins": stats["num_nonzero_bins"], - } - - # Counter-based metrics (Stage 2 cost analysis) - if args.counters: - inputs["sparse_kv_indices"].zero_() - counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") - topk_profile_counters( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - counter_buf, - eff_bs, - args.topk_val, - args.reserved_bos, - args.reserved_eos, - inputs["num_pages_per_seg"], - mode, - power, - lut_tensor, # lut - None, # quantiles - noscale, - ) - torch.cuda.synchronize() - c = counter_buf.float() - result["num_equal_mean"] = c[:, 2].mean().item() - result["remaining_k_mean"] = c[:, 3].mean().item() - result["refine_rounds_mean"] = c[:, 4].mean().item() - result["stage2_input_mean"] = c[:, 5].mean().item() - result["res_rate_mean"] = (c[:, 3] == 0).float().mean().item() - - return result - - # Baselines - for mode, (name, default_power) in BASELINES.items(): - r = evaluate(mode, default_power, f"m{mode}_{name}") - results.append(r) - - # Parametric sweep (scaled) - for mode, (param_name, values) in SWEEP_GRID.items(): - mname = MODE_NAMES[mode] - for val in values: - label = f"m{mode}_{mname}_{param_name}={val}" - r = evaluate(mode, val, label) - results.append(r) - - # Noscale sweep for parametric modes - for mode, (name, values) in NOSCALE_BASELINES.items(): - mname = MODE_NAMES[mode] - for val in values: - label = f"m{mode}_{mname}_noscale_{val}" - r = evaluate(mode, val, label, noscale=True) - results.append(r) - - # LUT approximation sweep: generate a LUT for each (mode, param) and - # evaluate via mode=1 (LUT CDF). This replaces per-element transcendentals - # with a single shared memory lookup. - if args.lut_sweep: - lut_modes = { - 3: [0.25, 0.5, 0.75], - 6: [0.5, 1.0, 2.0], - 7: [0.5, 1.0, 2.0], - 9: [0.5, 1.0, 2.0], - 10: [0.5, 1.0, 2.0], - } - for src_mode, params in lut_modes.items(): - src_name = MODE_NAMES[src_mode] - for p in params: - try: - lut_np = generate_remap_lut(src_mode, p) - lut_t = torch.from_numpy(lut_np).cuda() - label = f"lut_{src_name}_{p}" - # Evaluate as mode=1 (LUT CDF) with the generated LUT - r = evaluate(1, 0.5, label, lut_tensor=lut_t) - r["lut_source_mode"] = src_mode - r["lut_source_param"] = p - results.append(r) - except ImportError: - # scipy not available for erf - pass - - return results - - -def print_table(results: List[dict], show_latency: bool = False): - """Print ranked results as a formatted table.""" - has_counters = any("res_rate_mean" in r for r in results) - has_latency = any("full_kernel_ms" in r for r in results) - - # Primary ranking: by res_rate_mean (higher=better) if counters, else by gini (lower=better) - if has_counters: - ranked = sorted(results, key=lambda r: -r.get("res_rate_mean", 0.0)) - rank_label = "ranked by res_rate, higher=better" - else: - ranked = sorted(results, key=lambda r: r["gini"]) - rank_label = "ranked by Gini, lower=better" - - # Build header - cols = f"{'Rank':>4s} {'Label':<35s} {'Dist':<12s} {'Gini':>6s} {'Max/Mean':>8s} {'NZBins':>6s}" - if has_counters: - cols += f" {'ResRate':>7s} {'RemK':>5s} {'Rnds':>4s} {'S2In':>5s}" - if has_latency and show_latency: - cols += f" {'LatMs':>9s} {'LatRk':>5s}" - - print(f"\n{'=' * len(cols)}") - print(f"TopK Mapping Auto-Tune Results ({rank_label})") - print("=" * len(cols)) - print(cols) - print("-" * len(cols)) - - for i, r in enumerate(ranked): - noscale_tag = " [NS]" if r.get("noscale", False) else "" - line = ( - f"{i+1:4d} {r['label'] + noscale_tag:<35s} {r['distribution']:<12s} " - f"{r['gini']:6.3f} " - f"{r['max_mean_ratio']:8.2f} {r['num_nonzero_bins']:6d}" - ) - if has_counters: - rr = r.get("res_rate_mean", 0.0) - rk = r.get("remaining_k_mean", 0.0) - rnds = r.get("refine_rounds_mean", 0.0) - s2in = r.get("stage2_input_mean", 0.0) - line += f" {rr:7.3f} {rk:5.0f} {rnds:4.1f} {s2in:5.0f}" - if has_latency and show_latency: - lat = r.get("full_kernel_ms", float("nan")) - lat_rank = r.get("latency_rank", "-") - line += f" {lat:9.4f} {lat_rank:>5s}" if isinstance(lat_rank, str) else f" {lat:9.4f} {lat_rank:5d}" - print(line) - - print("=" * len(cols)) - if ranked: - best = ranked[0] - msg = ( - f"\nBest overall: {best['label']} (dist={best['distribution']}) " - f"— gini={best['gini']:.3f}, max/mean={best['max_mean_ratio']:.2f}" - ) - if has_counters: - msg += f", res_rate={best.get('res_rate_mean', 0):.3f}" - if "full_kernel_ms" in best: - msg += f", latency={best['full_kernel_ms']:.4f}ms" - print(msg) - - # If latency data available, also print best by latency - if has_latency and show_latency: - lat_ranked = sorted([r for r in results if "full_kernel_ms" in r], - key=lambda r: r["full_kernel_ms"]) - if lat_ranked: - best_lat = lat_ranked[0] - print( - f"Best by latency: {best_lat['label']} (dist={best_lat['distribution']}) " - f"— latency={best_lat['full_kernel_ms']:.4f}ms, gini={best_lat['gini']:.3f}" - ) - - # Per-mode best summary - mode_best = {} - for r in results: - m = r["mode"] - if has_counters: - is_better = m not in mode_best or r.get("res_rate_mean", 0) > mode_best[m].get("res_rate_mean", 0) - else: - is_better = m not in mode_best or r["gini"] < mode_best[m]["gini"] - if is_better: - mode_best[m] = r - - if mode_best: - print("\nBest per mode:") - for m in sorted(mode_best.keys()): - r = mode_best[m] - mname = MODE_NAMES.get(m, f"m{m}") - if m in SWEEP_GRID: - param_name = SWEEP_GRID[m][0] - param_str = f"{param_name}={r['param']}" - else: - param_str = "(baseline)" - ns_str = " noscale" if r.get("noscale", False) else "" - lat_str = f" latency={r['full_kernel_ms']:.4f}ms" if "full_kernel_ms" in r else "" - counter_str = f" res_rate={r.get('res_rate_mean', 0):.3f}" if has_counters else "" - print( - f" Mode {m:d} ({mname:>5s}{ns_str}): {param_str:<20s} " - f"gini={r['gini']:.3f} max/mean={r['max_mean_ratio']:.2f}{counter_str}{lat_str}" - ) - - -def latency_rerank(results: List[dict], args) -> List[dict]: - """Re-rank top Gini candidates by actual kernel latency.""" - # Sort by Gini, take top N - ranked = sorted(results, key=lambda r: r["gini"]) - finalists = ranked[:args.latency_top_n] - - print(f"\n--- Latency re-ranking: timing top {len(finalists)} Gini finalists ---") +# ---------- Latency-based evaluation ---------- - # Build inputs for latency measurement - real_histogram = None - if args.real_histograms: - raw = np.load(args.real_histograms) - real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw +def _time_fused(inputs, args, mode: int, power: float) -> dict: + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + inputs["sparse_kv_indices"].zero_() + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + call_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + return bench_kernel(topk_output_sglang_fused, call_args, + warmup=args.warmup, repeat=args.repeat) - if real_histogram is not None: - inputs = make_real_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - histogram=real_histogram, - ) - else: - inputs = make_topk_inputs( - batch_size=args.batch_size, - num_kv_heads=args.num_kv_heads, - seq_len=args.seq_len, - page_size=args.page_size, - topk_val=args.topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=torch.bfloat16, - distribution="normal", - ) +def _collect_diagnostics(inputs, args, mode: int, power: float) -> dict: + """Optional distribution/counter stats for reporting only (post-timing).""" eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] + diag = {} + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + + if args.collect_stats: + hist = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], inputs["dense_kv_indptr"], hist, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, lut_t, q_t, + ) + torch.cuda.synchronize() + diag.update(compute_histogram_stats(hist)) - for r in finalists: + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") inputs["sparse_kv_indices"].zero_() - # For LUT-generated entries, regenerate the LUT tensor - lut_tensor = None - if "lut_source_mode" in r: - lut_np = generate_remap_lut(r["lut_source_mode"], r["lut_source_param"]) - lut_tensor = torch.from_numpy(lut_np).cuda() - call_args = ( + topk_profile_counters( inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + counter_buf, eff_bs, args.topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, - r["mode"], - r["param"], - lut_tensor, # lut - None, # quantiles - r.get("noscale", False), + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + diag["threshold_bin_mean"] = c[:, 0].mean().item() + diag["num_equal_mean"] = c[:, 2].mean().item() + diag["refine_rounds_mean"] = c[:, 4].mean().item() + + return diag + + +def _run_sweep(args, inputs, dist_label: str) -> List[dict]: + results = [] + + # Baselines: time them but their param is fixed. + for mode, power in BASELINES: + lat = _time_fused(inputs, args, mode, power) + entry = { + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param_name": "(baseline)", + "param": power, + "distribution": dist_label, + "latency_ms": lat["mean_ms"], + "latency_median_ms": lat["median_ms"], + "latency_min_ms": lat["min_ms"], + } + entry.update(_collect_diagnostics(inputs, args, mode, power)) + results.append(entry) + print( + f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) baseline " + f" latency={lat['mean_ms']:.4f} ms" ) - latency = bench_kernel(topk_output_sglang, call_args, - warmup=10, repeat=args.latency_repeat) - r["full_kernel_ms"] = latency["mean_ms"] - print(f" {r['label']:<35s} gini={r['gini']:.3f} latency={latency['mean_ms']:.4f}ms") - # Re-rank finalists by latency - finalists.sort(key=lambda r: r["full_kernel_ms"]) - for i, r in enumerate(finalists): - r["latency_rank"] = i + 1 - r["gini_rank"] = next(j+1 for j, x in enumerate(ranked) if x is r) + # Parametric sweep, one (mode, param) combo at a time. + for mode, values in SWEEP_GRID.items(): + pname = PARAM_NAME[mode] + for val in values: + lat = _time_fused(inputs, args, mode, float(val)) + entry = { + "mode": mode, + "mode_name": MODE_NAMES.get(mode, f"m{mode}"), + "param_name": pname, + "param": float(val), + "distribution": dist_label, + "latency_ms": lat["mean_ms"], + "latency_median_ms": lat["median_ms"], + "latency_min_ms": lat["min_ms"], + } + entry.update(_collect_diagnostics(inputs, args, mode, float(val))) + results.append(entry) + print( + f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) {pname}={val:<6.3f} " + f" latency={lat['mean_ms']:.4f} ms" + ) return results -def main(): - parser = argparse.ArgumentParser( - description="Auto-tune TopK mapping hyperparameters" +def _print_ranked(results: List[dict]) -> None: + ranked = sorted(results, key=lambda r: r["latency_ms"]) + header = ( + f"{'Rank':>4s} {'Mode':<12s} {'Param':<14s} {'Dist':<10s} {'Latency (ms)':>14s}" ) + print("\n" + "=" * len(header)) + print("TopK auto-tune results (ranked by measured kernel latency, lower is better)") + print("=" * len(header)) + print(header) + print("-" * len(header)) + for i, r in enumerate(ranked): + param_str = f"{r['param_name']}={r['param']}" if r["param_name"] != "(baseline)" else "(baseline)" + print( + f"{i + 1:4d} {r['mode_name']:<12s} {param_str:<14s} " + f"{r['distribution']:<10s} {r['latency_ms']:14.4f}" + ) + print("=" * len(header)) + + # Best per mode. + best: Dict[int, dict] = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + print("\nBest per mode (by latency):") + for m in sorted(best.keys()): + r = best[m] + param_str = f"{r['param_name']}={r['param']}" if r["param_name"] != "(baseline)" else "(baseline)" + print( + f" mode {m:>2d} ({r['mode_name']:>5s}): {param_str:<16s} " + f"latency={r['latency_ms']:.4f} ms" + ) + + +def main(): + parser = argparse.ArgumentParser("TopK mapping hyperparameter auto-tuner (latency-driven)") parser.add_argument("--batch-size", type=int, default=4) - parser.add_argument("--seq-len", type=int, default=4096) - parser.add_argument("--topk-val", type=int, default=30) - parser.add_argument("--num-kv-heads", type=int, default=2) + parser.add_argument("--num-kv-heads", type=int, default=8) + parser.add_argument("--seq-len", type=int, default=65536) + parser.add_argument("--topk-val", type=int, default=2048) parser.add_argument("--page-size", type=int, default=16) parser.add_argument("--reserved-bos", type=int, default=1) parser.add_argument("--reserved-eos", type=int, default=2) - parser.add_argument( - "--distributions", nargs="+", - default=["normal"], - help="Score distributions for synthetic data (ignored when --real-histograms is set)", - ) - parser.add_argument( - "--real-histograms", type=str, default=None, - help="Path to raw_histograms.npy from calibration. When set, auto-tunes on " - "real score distribution instead of synthetic data.", - ) - parser.add_argument( - "--output-json", type=str, default=None, - help="Save results to JSON file", - ) - parser.add_argument("--latency-rerank", action="store_true", - help="Re-rank top Gini finalists by actual kernel latency") - parser.add_argument("--latency-top-n", type=int, default=10, - help="Number of Gini finalists to re-rank by latency (default: 10)") - parser.add_argument("--latency-repeat", type=int, default=50, - help="Kernel timing repetitions for latency measurement (default: 50)") - parser.add_argument("--counters", action="store_true", - help="Collect counter-based metrics (Stage 2 cost analysis) for each config") - parser.add_argument("--lut-sweep", action="store_true", - help="Generate and evaluate LUT approximations for parametric transform modes") + parser.add_argument("--distributions", type=str, nargs="+", + default=["normal"], + help="Synthetic distributions when --real-histograms is not set.") + parser.add_argument("--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibration.") + parser.add_argument("--warmup", type=int, default=20) + parser.add_argument("--repeat", type=int, default=100) + parser.add_argument("--collect-stats", action="store_true", + help="Also collect histogram + counter diagnostics (post-timing, no cost).") + parser.add_argument("--output-json", type=str, default=None) + parser.add_argument("--lut-path", type=str, default=None, + help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).") + parser.add_argument("--quantiles-path", type=str, default=None, + help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") args = parser.parse_args() - source = f"real ({args.real_histograms})" if args.real_histograms else f"synthetic ({args.distributions})" - print(f"Auto-tuning TopK mapping hyperparameters") - print(f" batch_size={args.batch_size}, seq_len={args.seq_len}, " - f"topk_val={args.topk_val}, num_kv_heads={args.num_kv_heads}") - print(f" score source: {source}") - n_parametric = sum(len(v) for _, v in SWEEP_GRID.values()) - n_baselines = len(BASELINES) - n_dists = 1 if args.real_histograms else len(args.distributions) - print(f" sweep: {n_parametric} parametric + {n_baselines} baselines " - f"= {n_parametric + n_baselines} combos x {n_dists} dists") + args._mapping_lut = None + args._mapping_quantiles = None + # Include modes 1/2 as baselines when calibration tables are provided. + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + args._mapping_lut = torch.from_numpy(lut_np).cuda() + BASELINES.append((1, 0.5)) + print(f"[autotune] loaded LUT from {args.lut_path}") + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + args._mapping_quantiles = torch.from_numpy(q_np).cuda() + BASELINES.append((2, 0.5)) + print(f"[autotune] loaded quantiles from {args.quantiles_path}") + + real_histogram: Optional[np.ndarray] = None + if args.real_histograms: + raw = np.load(args.real_histograms) + real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw - results = run_sweep(args) + all_results: List[dict] = [] - if args.latency_rerank: - results = latency_rerank(results, args) + if real_histogram is not None: + inputs = _make_real_inputs(args, real_histogram) + print("\n=== Latency sweep on REAL distribution " + f"(batch={args.batch_size} heads={args.num_kv_heads} seq={args.seq_len} topk={args.topk_val}) ===") + all_results += _run_sweep(args, inputs, "real") + else: + for dist in args.distributions: + inputs = make_topk_inputs( + batch_size=args.batch_size, + num_kv_heads=args.num_kv_heads, + seq_len=args.seq_len, + page_size=args.page_size, + topk_val=args.topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=dist, + ) + print(f"\n=== Latency sweep on synthetic dist={dist} ===") + all_results += _run_sweep(args, inputs, dist) - print_table(results, show_latency=args.latency_rerank) + _print_ranked(all_results) if args.output_json: with open(args.output_json, "w") as f: - json.dump(results, f, indent=2) + json.dump(all_results, f, indent=2) print(f"\nResults saved to {args.output_json}") diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index a913bde5..4bd5becb 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -1,36 +1,43 @@ """ TopK kernel benchmarking suite. -Measures kernel-level latency for the three topk variants (naive/CUB, -sglang with mapping modes) across configurable grid of batch sizes, -sequence lengths, topk values, and KV head counts. - -Usage: - python benchmarking/bench_topk.py --batch-sizes 4 8 --seq-lens 2048 4096 --topk-vals 30 --num-kv-heads 2 --repeat 50 +Lean rewrite after the remap-benchmark refactor. Exposes three public +helpers used by autotune_topk_mapping.py (make_topk_inputs, bench_kernel, +compute_histogram_stats) and a CLI with two modes: + + - default : time the baseline (unmapped) kernel and the fused + kernel across a grid of (mode, power, batch, seq_len, + topk_val, distribution) configs. + - --remap-bench: time baseline vs fused vs split-phase (remap-only + + unmapped-topk-on-remapped) and report threshold stats + from topk_profile_counters. """ import argparse import json import math import statistics -from typing import Dict, List, Optional +from typing import Dict, List import numpy as np import torch from vortex_torch_C import ( - topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram, - topk_profile_stage1, topk_profile_counters, + topk_output, + topk_output_sglang, # unmapped baseline + topk_output_sglang_fused, # fused remap + topk + topk_remap_only, # standalone remap + topk_profile_histogram, + topk_profile_counters, ) -# Canonical mapping mode names — used in logs, tables, and plots + MAPPING_MODE_NAMES = { 0: "None", - 1: "LUT CDF", + 1: "LUT_CDF", 2: "Quantile", 3: "Power", 4: "Log", - 5: "Index Cache", 6: "Asinh", 7: "Log1p", 8: "Trunc8", @@ -38,25 +45,30 @@ 10: "Tanh", 11: "Subtract", 13: "ExpStretch", - 14: "TopkWindow", } -MAPPING_MODE_FORMULAS = { - 0: "None (fp16 bucketing)", - 1: "LUT CDF (calibrated)", - 2: "Quantile (calibrated)", - 3: "Power: sign(x)*|x|^p", - 4: "Log: sign(x)*log(|x|+1)", - 5: "Index Cache", - 6: "Asinh: asinh(beta*x)", - 7: "Log1p: sign(x)*log1p(alpha*|x|)", - 8: "Trunc8: bf16 upper-8-bit bucketing", - 9: "Erf: erf(alpha*x)", - 10: "Tanh: tanh(alpha*x)", - 11: "Subtract: x - pivot (RadiK-style)", - 13: "ExpStretch: exp(alpha*x)", - 14: "TopkWindow: k-aware linear windowing", -} + +def _load_autotune_hparams(path: str) -> Dict[int, float]: + """Load per-mode best hyperparameters from an autotune_results.json. + + The JSON is produced by autotune_topk_mapping.py and contains a list of + {mode, param, latency_ms, ...} entries. For each mode we pick the entry + with the lowest measured latency and return {mode: best_param}. + + Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; the + caller should override to taste. + """ + with open(path) as f: + data = json.load(f) + best: Dict[int, dict] = {} + for r in data: + m = r.get("mode") + lat = r.get("latency_ms") + if m is None or lat is None: + continue + if m not in best or lat < best[m]["latency_ms"]: + best[m] = r + return {m: float(r["param"]) for m, r in best.items()} def make_topk_inputs( @@ -71,7 +83,7 @@ def make_topk_inputs( distribution: str = "normal", device: str = "cuda", ) -> dict: - """Synthesize realistic CSR-formatted paged attention inputs.""" + """Synthesize CSR-formatted paged attention inputs for kernel timing.""" eff_batch_size = batch_size * num_kv_heads num_pages_per_seg = math.ceil(seq_len / page_size) total_dense_pages = eff_batch_size * num_pages_per_seg @@ -89,7 +101,6 @@ def make_topk_inputs( dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) - # Generate scores with the requested distribution if distribution == "normal": x = torch.randn(total_dense_pages, 1, 1, device=device) elif distribution == "lognormal": @@ -97,14 +108,11 @@ def make_topk_inputs( elif distribution == "uniform": x = torch.rand(total_dense_pages, 1, 1, device=device) elif distribution == "bucket_uniform": - # Uniform across all 256 fp16 radix buckets. - # Random uint16 bit patterns → interpret as fp16. - # Bucket = upper 8 bits of sign-flipped fp16, so random bits → uniform buckets. + # Uniform across all 256 fp16 radix buckets. Random uint16 bit + # patterns → interpret as fp16. NaN/Inf patterns collapse to ±0. raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device) - # Exclude fp16 NaN/Inf (exponent=31, i.e. |bits| >= 0x7C00) abs_bits = raw_bits & 0x7FFF - raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 # → ±0 - # Reinterpret int16 bits as fp16, then widen to float32 + raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1) else: raise ValueError(f"Unknown distribution: {distribution}") @@ -124,7 +132,7 @@ def make_topk_inputs( def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict: - """Time a kernel with CUDA events, return latency stats in ms.""" + """Time a kernel with CUDA events. Returns latency stats in ms.""" for _ in range(warmup): kernel_fn(*args) torch.cuda.synchronize() @@ -148,796 +156,304 @@ def bench_kernel(kernel_fn, args, warmup: int = 10, repeat: int = 100) -> dict: def compute_histogram_stats(histograms: torch.Tensor) -> dict: - """Compute bin distribution statistics from histogram tensor [B, 256].""" + """Bin distribution statistics from histogram tensor [B, 256].""" h = histograms.float() - # Aggregate across batch dimension h_sum = h.sum(dim=0) # [256] - nonzero_bins = h_sum[h_sum > 0] - if len(nonzero_bins) == 0: + nonzero = h_sum[h_sum > 0] + if len(nonzero) == 0: return { "max_mean_ratio": 0.0, "std": 0.0, "gini": 0.0, "num_nonzero_bins": 0, "entropy": 0.0, "effective_bins": 0.0, } - - mean_val = nonzero_bins.mean().item() - max_val = nonzero_bins.max().item() - std_val = nonzero_bins.std().item() if len(nonzero_bins) > 1 else 0.0 - - # Gini coefficient - sorted_bins = nonzero_bins.sort().values + mean_val = nonzero.mean().item() + max_val = nonzero.max().item() + std_val = nonzero.std().item() if len(nonzero) > 1 else 0.0 + sorted_bins = nonzero.sort().values n = len(sorted_bins) - index = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) - gini = (2.0 * (index * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() - - # Shannon entropy (base-2) - p = nonzero_bins / nonzero_bins.sum() + idx = torch.arange(1, n + 1, device=sorted_bins.device, dtype=torch.float32) + gini = (2.0 * (idx * sorted_bins).sum() / (n * sorted_bins.sum()) - (n + 1) / n).item() + p = nonzero / nonzero.sum() entropy = -(p * p.log2()).sum().item() - # Effective number of bins: 2^entropy - effective_bins = 2 ** entropy - return { "max_mean_ratio": max_val / mean_val if mean_val > 0 else 0.0, "std": std_val, "gini": max(0.0, gini), - "num_nonzero_bins": int(len(nonzero_bins)), + "num_nonzero_bins": int(len(nonzero)), "entropy": entropy, - "effective_bins": effective_bins, + "effective_bins": 2 ** entropy, } -NUM_HISTOGRAM_BINS = 256 - +def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, power: float) -> dict: + """Run topk_profile_counters once and aggregate threshold-bin stats. -def _histogram_target_pages(pages_per_seg: int, min_samples_per_bin: int = 512) -> int: - """Compute adaptive page count for statistically reliable histograms. - - With 256 radix bins, each bin needs enough samples for stable gini / - max-mean statistics. Returns a total page count rounded up to a full - segment boundary so every segment contributes equally. + Profile kernel is invoked AFTER all latency measurements have finished, + so the counter writes never contaminate timing. """ - min_pages = min_samples_per_bin * NUM_HISTOGRAM_BINS - return math.ceil(min_pages / pages_per_seg) * pages_per_seg + eff_bs = inputs["eff_batch_size"] + counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") + inputs["sparse_kv_indices"].zero_() + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + topk_profile_counters( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + counter_buf, + eff_bs, + topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + c = counter_buf.float() + # Selected from threshold bin = topk_val - num_above (clamped >= 0). + sel_from_thr = (float(topk_val) - c[:, 1]).clamp(min=0.0) + return { + "threshold_bin_mean": c[:, 0].mean().item(), + "threshold_bin_max": c[:, 0].max().item(), + "num_above_mean": c[:, 1].mean().item(), + "threshold_bin_size_mean": c[:, 2].mean().item(), # NUM_EQUAL + "threshold_bin_size_max": c[:, 2].max().item(), + "selected_from_thr_mean": sel_from_thr.mean().item(), + "selected_from_thr_max": sel_from_thr.max().item(), + "refine_rounds_mean": c[:, 4].mean().item(), + } -def _load_autotune_powers(path: str) -> Dict[int, float]: - """Extract best per-mode power from autotune JSON. +def _resolve_hparam(args, mode: int) -> float: + """Pick the hyperparameter for a mode: autotune JSON wins, then --mapping-hparam.""" + if mode == 0: + return 0.5 # unused for MAPPING_NONE + hparams: Dict[int, float] = getattr(args, "_autotune_hparams", {}) or {} + if mode in hparams: + return hparams[mode] + return args.mapping_hparam - Ranks by res_rate_mean (higher=better) if present, else by gini (lower=better). - Returns {mode: best_power}, e.g. {3: 0.25, 6: 1.0, 7: 2.0}. - """ - with open(path) as f: - data = json.load(f) - has_res_rate = any("res_rate_mean" in r for r in data) +def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, + distribution, modes: List[int]) -> dict: + """Time baseline, fused, and split-phase for each mode at one config.""" + inputs = make_topk_inputs( + batch_size=batch_size, + num_kv_heads=num_kv_heads, + seq_len=seq_len, + page_size=args.page_size, + topk_val=topk_val, + reserved_bos=args.reserved_bos, + reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, + distribution=distribution, + ) + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + total_dense = inputs["x"].numel() + + # Baseline: unmapped topk. + baseline_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + baseline = bench_kernel(topk_output_sglang, baseline_args, args.warmup, args.repeat) + + # Pre-allocate the float32 buffer used for the split-phase (remap → baseline). + remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(inputs["x"].shape) + + config = { + "batch_size": batch_size, + "num_kv_heads": num_kv_heads, + "seq_len": seq_len, + "topk_val": topk_val, + "distribution": distribution, + "pages_per_seg": pages_per_seg, + "baseline_ms": baseline["mean_ms"], + "modes": [], + } - best: Dict[int, dict] = {} - for r in data: - m = r.get("mode") - if m not in (3, 6, 7, 9, 10, 13, 14): - continue - if has_res_rate: - score = r.get("res_rate_mean", 0.0) - is_better = m not in best or score > best[m]["_score"] - else: - score = r.get("gini", 1.0) - is_better = m not in best or score < best[m]["_score"] - if is_better: - best[m] = {"param": r["param"], "_score": score} + for mode in modes: + power = _resolve_hparam(args, mode) + + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None + fused_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, + ) + inputs["sparse_kv_indices"].zero_() + fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) + + # Split-phase timing: first the standalone remap, then the unmapped + # topk on the remapped buffer. + remap_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, + ) + remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + + split_topk_args = ( + remapped, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + # Run remap once so the buffer is populated for warmup of topk-on-remapped. + topk_remap_only(*remap_args) + torch.cuda.synchronize() + inputs["sparse_kv_indices"].zero_() + split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) + + # Counter collection is run AFTER all timing measurements for this mode + # so it cannot affect the timings. + stats = _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode, power) + + row = { + "mode": mode, + "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, + "remap_ms": remap_only["mean_ms"], + "topk_after_remap_ms": split_topk["mean_ms"], + "split_total_ms": remap_only["mean_ms"] + split_topk["mean_ms"], + "fused_ms": fused["mean_ms"], + **stats, + } + config["modes"].append(row) - return {m: v["param"] for m, v in best.items()} + return config -def _resolve_mode_hparam(args, mode: int) -> float: - """Return the power/beta/alpha for a parametric mapping mode. +def _print_remap_table(results: List[dict]) -> None: + header = ( + f"{'mode':<12s} {'remap_us':>9s} {'topk_us':>9s} {'split_us':>9s} " + f"{'fused_us':>9s} {'base_us':>9s} {'thr_bin':>7s} {'thr_size':>8s} {'sel_thr':>7s}" + ) + for cfg in results: + banner = ( + f"\n[batch={cfg['batch_size']} heads={cfg['num_kv_heads']} " + f"seq_len={cfg['seq_len']} topk={cfg['topk_val']} " + f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']}]" + ) + print(banner) + print(" Baseline: mapping_mode=0 (raw fp16 bucketing)") + print(header) + print("-" * len(header)) + base_us = cfg["baseline_ms"] * 1000.0 + for row in cfg["modes"]: + label = f"{row['mode_name']}(p={row['power']})" if row["mode"] != 0 else "None" + print( + f"{label:<12s} " + f"{row['remap_ms'] * 1000.0:9.2f} " + f"{row['topk_after_remap_ms'] * 1000.0:9.2f} " + f"{row['split_total_ms'] * 1000.0:9.2f} " + f"{row['fused_ms'] * 1000.0:9.2f} " + f"{base_us:9.2f} " + f"{row['threshold_bin_mean']:7.1f} " + f"{row['threshold_bin_size_mean']:8.1f} " + f"{row['selected_from_thr_mean']:7.1f}" + ) - Priority: per-mode CLI flag > autotune JSON > global --mapping-power. - """ - per_mode_flag = {3: args.mapping_hparam_3, 6: args.mapping_hparam_6, 7: args.mapping_hparam_7, - 9: getattr(args, 'mapping_hparam_9', None), 10: getattr(args, 'mapping_hparam_10', None), - 13: getattr(args, 'mapping_hparam_13', None), 14: getattr(args, 'mapping_hparam_14', None)} - if mode in per_mode_flag and per_mode_flag[mode] is not None: - return per_mode_flag[mode] - if hasattr(args, "_autotune_powers") and mode in args._autotune_powers: - return args._autotune_powers[mode] - return args.mapping_hparam +def _run_remap_bench(args) -> None: + modes = [int(m) for m in args.mapping_modes] + if 0 not in modes: + modes = [0] + modes -def _run_subphase_profiling(subphase_modes, inputs, eff_bs, topk_val, - pages_per_seg, args, mapping_lut, mapping_quantiles): - """Run sub-phase profiling (histogram_only + stage1_full) for each mode. + results = [] + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + for dist in args.distributions: + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, dist, modes, + ) + results.append(cfg) - For topk <= 512, runs inline. For topk > 512, runs each mode in a - separate subprocess to avoid CUDA shared memory exhaustion from - accumulated kernel template registrations. - """ - import subprocess, sys, tempfile, os - - for kernel_name, s1_mode, s1_power, s1_noscale, result in subphase_modes: - s1_lut = mapping_lut if s1_mode == 1 else None - s1_q = mapping_quantiles if s1_mode == 2 else None - - if topk_val <= 512: - # Inline: run directly in this process - hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") - hist_args = ( - inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, - args.reserved_bos, args.reserved_eos, - s1_mode, s1_power, s1_lut, s1_q, s1_noscale, - ) - hist_result = bench_kernel(topk_profile_histogram, hist_args, args.warmup, args.repeat) - - inputs["sparse_kv_indices"].zero_() - stage1_args = ( - inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, - s1_mode, s1_power, s1_lut, s1_q, s1_noscale, - ) - stage1_result = bench_kernel(topk_profile_stage1, stage1_args, args.warmup, args.repeat) - else: - # Subprocess: fresh CUDA context per mode to avoid shared memory exhaustion - script = f""" -import torch, json, sys -sys.path.insert(0, '{os.path.dirname(os.path.abspath(__file__))}') -from vortex_torch_C import topk_profile_histogram, topk_profile_stage1 -from bench_topk import make_topk_inputs, bench_kernel - -inputs = make_topk_inputs( - batch_size={inputs['x'].shape[0] // (eff_bs // (inputs['x'].shape[0] if eff_bs == inputs['x'].shape[0] else 1)) if False else 1}, - num_kv_heads=1, seq_len={pages_per_seg * 16}, - page_size=16, topk_val={topk_val}, - reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, - score_dtype=torch.bfloat16, distribution="normal", -) -eff_bs = {eff_bs} -# Recreate inputs with correct eff_bs -inputs = make_topk_inputs( - batch_size={eff_bs // max(1, eff_bs // pages_per_seg) if False else eff_bs}, - num_kv_heads=1, seq_len={pages_per_seg * 16}, - page_size=16, topk_val={topk_val}, - reserved_bos={args.reserved_bos}, reserved_eos={args.reserved_eos}, - score_dtype=torch.bfloat16, distribution="normal", -) -eff_bs = inputs["eff_batch_size"] - -hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") -hist_result = bench_kernel(topk_profile_histogram, - (inputs["x"], inputs["dense_kv_indptr"], hist_buf, eff_bs, - {args.reserved_bos}, {args.reserved_eos}, {s1_mode}, {s1_power}, - None, None, {s1_noscale}), 5, {args.repeat}) - -inputs["sparse_kv_indices"].zero_() -stage1_result = bench_kernel(topk_profile_stage1, - (inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - eff_bs, {topk_val}, {args.reserved_bos}, {args.reserved_eos}, - inputs["num_pages_per_seg"], {s1_mode}, {s1_power}, - None, None, {s1_noscale}), 5, {args.repeat}) - -print(json.dumps({{"hist": hist_result, "stage1": stage1_result}})) -""" - try: - proc = subprocess.run( - [sys.executable, "-c", script], - capture_output=True, text=True, timeout=60, - env={**os.environ, "PYTHONPATH": os.path.dirname(os.path.abspath(__file__)) + "/.."}) - if proc.returncode == 0: - data = json.loads(proc.stdout.strip().split("\n")[-1]) - hist_result = data["hist"] - stage1_result = data["stage1"] - else: - # Subprocess failed — skip sub-phase for this mode - continue - except Exception: - continue - - result['histogram_only_mean_ms'] = hist_result['mean_ms'] - result['histogram_only_median_ms'] = hist_result['median_ms'] - result['stage1_full_mean_ms'] = stage1_result['mean_ms'] - result['stage1_full_median_ms'] = stage1_result['median_ms'] - result['route_overhead_mean_ms'] = stage1_result['mean_ms'] - hist_result['mean_ms'] - result['route_overhead_median_ms'] = stage1_result['median_ms'] - hist_result['median_ms'] - result['stage2_refine_mean_ms'] = result['mean_ms'] - stage1_result['mean_ms'] - result['stage2_refine_median_ms'] = result['median_ms'] - stage1_result['median_ms'] - - -def run_benchmark(args) -> List[dict]: - """Run the full benchmark sweep and return results.""" - # Load autotune results if provided - if args.autotune_json: - args._autotune_powers = _load_autotune_powers(args.autotune_json) - print(f"Loaded autotune best powers: {args._autotune_powers}") - else: - args._autotune_powers = {} - - dtype_map = {"bfloat16": torch.bfloat16, "float32": torch.float32} - score_dtype = dtype_map[args.score_dtype] - - # Load real histogram if provided - real_histogram = None - _scores_from_histogram = None - if args.real_histograms: - from autotune_topk_mapping import scores_from_histogram - _scores_from_histogram = scores_from_histogram - raw = np.load(args.real_histograms) - real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw - - # Extend distributions with "real" if calibration data is provided - distributions = list(args.distributions) - if real_histogram is not None: - distributions.append("real") - args.distributions = distributions - - # Print GPU info - gpu_name = torch.cuda.get_device_name(0) - gpu_props = torch.cuda.get_device_properties(0) - print(f"TopK Kernel Benchmark Results") - print(f"GPU: {gpu_name} | SM count: {gpu_props.multi_processor_count}") - print(f"Score dtype: {args.score_dtype} | Warmup: {args.warmup} | Repeat: {args.repeat}") - print(f"Radix bits: {args.radix_bits} ({1 << args.radix_bits} bins) | Sample stride: {args.sample_stride}") - print("=" * 90) - - # Load optional LUT / quantiles - mapping_lut = None - mapping_quantiles = None - if args.lut_path: - lut_np = np.load(args.lut_path).astype(np.uint8) - mapping_lut = torch.from_numpy(lut_np).cuda() - if args.quantiles_path: - q_np = np.load(args.quantiles_path).astype(np.float32) - mapping_quantiles = torch.from_numpy(q_np).cuda() - - # Build kernel list - all_kernels = { - "naive": "naive", - "sglang_ori": "sglang_ori", - "sglang_m0": "sglang_m0", - "sglang_scale": "sglang_scale", # mode 3 with p=1.0 (identity + linear auto-range scaling) - "sglang_m3": "sglang_m3", - "sglang_m3_noscale": "sglang_m3_noscale", - "sglang_m4": "sglang_m4", - "sglang_m6": "sglang_m6", - "sglang_m6_noscale": "sglang_m6_noscale", - "sglang_m7": "sglang_m7", - "sglang_m7_noscale": "sglang_m7_noscale", - "sglang_m8": "sglang_m8", - "sglang_m9": "sglang_m9", - "sglang_m9_noscale": "sglang_m9_noscale", - "sglang_m10": "sglang_m10", - "sglang_m10_noscale": "sglang_m10_noscale", - "sglang_m11": "sglang_m11", - "sglang_m13": "sglang_m13", - "sglang_m13_noscale": "sglang_m13_noscale", - "sglang_m14": "sglang_m14", - } - if mapping_lut is not None: - all_kernels["sglang_m1"] = "sglang_m1" - if mapping_quantiles is not None: - all_kernels["sglang_m2"] = "sglang_m2" - - if args.filter_kernels: - # Validate: if the user explicitly requested sglang_m1 or sglang_m2 but - # the required calibration file was not provided, fail loudly instead of - # silently skipping these modes. - if "sglang_m1" in args.filter_kernels and "sglang_m1" not in all_kernels: - raise RuntimeError( - "sglang_m1 (LUT CDF) was requested in --filter-kernels but no " - "--lut-path was provided. Mode 1 requires a calibrated LUT file " - "(lut.npy from calibrate_topk.py). Either supply --lut-path or " - "remove sglang_m1 from --filter-kernels." - ) - if "sglang_m2" in args.filter_kernels and "sglang_m2" not in all_kernels: - raise RuntimeError( - "sglang_m2 (Quantile) was requested in --filter-kernels but no " - "--quantiles-path was provided. Mode 2 requires a calibrated " - "quantiles file (quantiles.npy from calibrate_topk.py). Either " - "supply --quantiles-path or remove sglang_m2 from --filter-kernels." - ) - all_kernels = {k: v for k, v in all_kernels.items() if k in args.filter_kernels} + _print_remap_table(results) - # Naive kernel only supports bf16 - if score_dtype != torch.bfloat16 and "naive" in all_kernels: - print(f"Note: naive kernel only supports bfloat16, skipping for {args.score_dtype}") - del all_kernels["naive"] + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to {args.output_json}") - all_results = [] +def _run_latency_sweep(args) -> None: + """Simple baseline-vs-fused latency sweep (no split-phase, no counters).""" + modes = [int(m) for m in args.mapping_modes] + results = [] for bs in args.batch_sizes: - for seq_len in args.seq_lens: - for topk_val in args.topk_vals: - for num_kv_heads in args.num_kv_heads: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: for dist in args.distributions: - if dist == "real" and real_histogram is not None: - inputs = make_topk_inputs( - batch_size=bs, - num_kv_heads=num_kv_heads, - seq_len=seq_len, - page_size=args.page_size, - topk_val=topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=score_dtype, - distribution="normal", - ) - # Replace scores with real-distribution scores - total_dense = inputs["eff_batch_size"] * inputs["num_pages_per_seg"] - inputs["x"] = _scores_from_histogram( - real_histogram, total_dense, device="cuda", - ) - else: - inputs = make_topk_inputs( - batch_size=bs, - num_kv_heads=num_kv_heads, - seq_len=seq_len, - page_size=args.page_size, - topk_val=topk_val, - reserved_bos=args.reserved_bos, - reserved_eos=args.reserved_eos, - score_dtype=score_dtype, - distribution=dist, - ) - + inputs = make_topk_inputs( + batch_size=bs, num_kv_heads=heads, seq_len=seq_len, + page_size=args.page_size, topk_val=topk_val, + reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, + score_dtype=torch.bfloat16, distribution=dist, + ) eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] - - config_str = ( - f"bs={bs} | seq={seq_len} | topk={topk_val} | " - f"heads={num_kv_heads} | pages/seg={pages_per_seg} | dist={dist}" - ) - print(f"\n{config_str}") - - config_results = { - "batch_size": bs, - "seq_len": seq_len, - "topk_val": topk_val, - "num_kv_heads": num_kv_heads, - "distribution": dist, - "eff_batch_size": eff_bs, - "pages_per_seg": pages_per_seg, - "kernels": {}, - } - - # Collect all kernel results first, then print sorted by latency - kernel_entries = [] # [(label, kernel_name, result)] - - for kernel_name in all_kernels: - # Reset sparse indices each run + row_modes = [] + for mode in modes: + power = _resolve_hparam(args, mode) inputs["sparse_kv_indices"].zero_() - - if kernel_name == "naive": - call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indptr"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - ) - result = bench_kernel(topk_output, call_args, args.warmup, args.repeat) - elif kernel_name == "sglang_ori": + if mode == 0: + call = topk_output_sglang call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - args.radix_bits, + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, ) - result = bench_kernel(topk_output_sglang_ori, call_args, args.warmup, args.repeat) - elif kernel_name == "sglang_scale": - call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - 3, # mode 3 (power) - 1.0, # p=1.0 → identity - None, - None, - False, - args.sample_stride, - args.radix_bits, - ) - result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) else: - mode_str = kernel_name.split("_m")[1] - mode = int(mode_str.split("_")[0]) - is_noscale = kernel_name.endswith("_noscale") - extra_kwargs = {} - if mode == 1: - extra_kwargs["mapping_lut"] = mapping_lut - elif mode == 2: - extra_kwargs["mapping_quantiles"] = mapping_quantiles - - if mode in (3, 6, 7, 9, 10, 13, 14): - power = _resolve_mode_hparam(args, mode) - else: - power = 0.5 - + call = topk_output_sglang_fused + lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None + q_t = getattr(args, "_mapping_quantiles", None) if mode == 2 else None call_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - mode, - power, - extra_kwargs.get("mapping_lut", None), - extra_kwargs.get("mapping_quantiles", None), - is_noscale, - args.sample_stride, - args.radix_bits, - ) - result = bench_kernel(topk_output_sglang, call_args, args.warmup, args.repeat) - - # Build label - if kernel_name == "naive": - label = "naive" - elif kernel_name == "sglang_ori": - label = "sglang Ori (no remap)" - elif kernel_name == "sglang_scale": - label = "sglang Scale Only (p=1.0)" - else: - m_str = kernel_name.split("_m")[1] - m = int(m_str.split("_")[0]) - noscale_suffix = " noscale" if kernel_name.endswith("_noscale") else "" - mname = MAPPING_MODE_NAMES.get(m, f'm{m}') - if m in (3, 6, 7, 9, 10, 13, 14): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[m] - label = f"sglang {mname} ({pname}={_resolve_mode_hparam(args, m)}){noscale_suffix}" - else: - label = f"sglang {mname}{noscale_suffix}" - - # Counter collection (runs separately from sub-phase profiling) - if kernel_name not in ("naive",) and args.counters: - if kernel_name in ("sglang_ori",): - c_mode, c_power, c_lut, c_q, c_noscale = 0, 0.5, None, None, False - elif kernel_name == "sglang_scale": - c_mode, c_power, c_lut, c_q, c_noscale = 3, 1.0, None, None, False - else: - c_mode_str = kernel_name.split("_m")[1] - c_mode = int(c_mode_str.split("_")[0]) - c_noscale = kernel_name.endswith("_noscale") - c_power = _resolve_mode_hparam(args, c_mode) if c_mode in (3,6,7,9,10,13,14) else 0.5 - c_lut = mapping_lut if c_mode == 1 else None - c_q = mapping_quantiles if c_mode == 2 else None - inputs["sparse_kv_indices"].zero_() - counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") - counter_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], + inputs["x"], inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], inputs["dense_kv_indices"], inputs["sparse_kv_indices"], - counter_buf, - eff_bs, - topk_val, - args.reserved_bos, - args.reserved_eos, - pages_per_seg, - c_mode, - c_power, - c_lut, - c_q, - c_noscale, + eff_bs, topk_val, + args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, lut_t, q_t, ) - topk_profile_counters(*counter_args) - torch.cuda.synchronize() - c = counter_buf.float() - result['counters'] = { - 'threshold_bin_mean': c[:, 0].mean().item(), - 'num_above_mean': c[:, 1].mean().item(), - 'num_equal_mean': c[:, 2].mean().item(), - 'remaining_k_mean': c[:, 3].mean().item(), - 'refine_rounds_mean': c[:, 4].mean().item(), - 'stage2_input_mean': c[:, 5].mean().item(), - 'threshold_bin_max': c[:, 0].max().item(), - 'num_above_max': c[:, 1].max().item(), - 'num_equal_max': c[:, 2].max().item(), - 'remaining_k_max': c[:, 3].max().item(), - 'refine_rounds_max': c[:, 4].max().item(), - 'stage2_input_max': c[:, 5].max().item(), - } - - kernel_entries.append((label, kernel_name, result)) - config_results["kernels"][kernel_name] = result - - # Second pass: sub-phase profiling (histogram_only + stage1_full) - # Run in a subprocess to get a fresh CUDA context, avoiding - # shared memory exhaustion from accumulated kernel registrations. - subphase_modes = [] - for label, kernel_name, result in kernel_entries: - if kernel_name in ("naive", "sglang_ori"): - continue - if kernel_name == "sglang_scale": - s1_mode, s1_power, s1_noscale = 3, 1.0, False - else: - s1_mode_str = kernel_name.split("_m")[1] - s1_mode = int(s1_mode_str.split("_")[0]) - s1_noscale = kernel_name.endswith("_noscale") - s1_power = _resolve_mode_hparam(args, s1_mode) if s1_mode in (3,6,7,9,10,13,14) else 0.5 - subphase_modes.append((kernel_name, s1_mode, s1_power, s1_noscale, result)) - - if subphase_modes: - _run_subphase_profiling( - subphase_modes, inputs, eff_bs, topk_val, - pages_per_seg, args, mapping_lut, mapping_quantiles) - - # Print kernel results sorted by mean latency (ascending) - kernel_entries.sort(key=lambda e: e[2]['mean_ms']) - print(f" --- kernel latency (sorted by mean, ascending) ---") - for label, kernel_name, result in kernel_entries: + stats = bench_kernel(call, call_args, args.warmup, args.repeat) + row_modes.append({ + "mode": mode, "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, "mean_ms": stats["mean_ms"], + "median_ms": stats["median_ms"], + }) print( - f" {label:<40s}: " - f"mean={result['mean_ms']:.4f}ms " - f"median={result['median_ms']:.4f}ms " - f"\u00b1 {result['std_ms']:.4f}ms " - f"[min={result['min_ms']:.4f}, max={result['max_ms']:.4f}]" + f"bs={bs} h={heads} seq={seq_len} topk={topk_val} " + f"dist={dist} mode={mode:>2d} lat={stats['mean_ms']:.4f} ms" ) - if 'stage1_full_mean_ms' in result: - print( - f" {'Histogram only (map+hist)':<36s}: " - f"mean={result['histogram_only_mean_ms']:.4f}ms " - f"median={result['histogram_only_median_ms']:.4f}ms" - ) - print( - f" {'Stage1 full (hist+cumsum+route)':<36s}: " - f"mean={result['stage1_full_mean_ms']:.4f}ms " - f"median={result['stage1_full_median_ms']:.4f}ms" - ) - print( - f" {'Route overhead (cumsum+route)':<36s}: " - f"mean={result['route_overhead_mean_ms']:.4f}ms " - f"median={result['route_overhead_median_ms']:.4f}ms" - ) - print( - f" {'Stage2 (refine)':<36s}: " - f"mean={result['stage2_refine_mean_ms']:.4f}ms " - f"median={result['stage2_refine_median_ms']:.4f}ms" - ) - if 'counters' in result: - c = result['counters'] - print( - f" Counters: threshold_bin={c['threshold_bin_mean']:.0f} " - f"above={c['num_above_mean']:.0f} " - f"equal={c['num_equal_mean']:.0f} " - f"remaining_k={c['remaining_k_mean']:.0f} " - f"refine_rounds={c['refine_rounds_mean']:.1f} " - f"stage2_input={c['stage2_input_mean']:.0f}" - ) - - # Histogram analysis — uses the SAME inputs as the main benchmark - # so histogram CSV and counters reflect the same data. - if args.histogram: - hist_inputs = inputs - hist_eff_bs = eff_bs - current_pages = eff_bs * pages_per_seg - print(f" histogram dataset : {current_pages} pages (same as benchmark)") - - # Raw unmapped histogram - histograms = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - histograms, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - ) - hstats = compute_histogram_stats(histograms) - hstats["raw_counts"] = histograms.sum(dim=0).tolist() # [256] ints - config_results["histogram"] = hstats - print( - f" histogram stats : max/mean={hstats['max_mean_ratio']:.2f} " - f"gini={hstats['gini']:.3f} " - f"nonzero_bins={hstats['num_nonzero_bins']}/256" - ) - - # Collect all histogram entries, then print sorted by gini - # Each entry: (display_name, key, mode_stats) - hist_entries = [] - histograms_results = {} - - # Per-mode histogram analysis (scaled) - modes_to_test = [0, 3, 4, 6, 7, 8, 9, 10, 11, 13, 14] - if mapping_lut is not None: - modes_to_test.append(1) - if mapping_quantiles is not None: - modes_to_test.append(2) - modes_to_test.sort() - - for mode in modes_to_test: - mode_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - - extra_lut = mapping_lut if mode == 1 else None - extra_q = mapping_quantiles if mode == 2 else None - power = _resolve_mode_hparam(args, mode) if mode in (3, 6, 7, 9, 10, 13, 14) else 0.5 - - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - mode_hists, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - mode, - power, - extra_lut, - extra_q, - False, # mapping_noscale - topk_val, # needed for mode 12/14 (tail/topk window) - ) - torch.cuda.synchronize() - - mode_stats = compute_histogram_stats(mode_hists) - mode_stats["raw_counts"] = mode_hists.sum(dim=0).tolist() - mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") - mformula = MAPPING_MODE_FORMULAS.get(mode, mname) - mode_stats["name"] = mname - mode_stats["formula"] = mformula - if mode in (3, 6, 7, 9, 10, 13, 14): - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha", 14: "rho"}[mode] - mode_stats["param"] = f"{pname}={power}" - display_name = f"{mname} ({pname}={power})" - else: - display_name = mname - key = f"mode_{mode}_{mname}" - histograms_results[key] = mode_stats - hist_entries.append((display_name, f"mode {mode:2d}", mode_stats)) - - # Noscale histogram analysis for parametric transform modes - noscale_modes = [m for m in (3, 6, 7, 9, 10, 13) if m in modes_to_test] - for mode in noscale_modes: - ns_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - power = _resolve_mode_hparam(args, mode) - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - ns_hists, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - mode, - power, - None, - None, - True, # mapping_noscale=True - ) - torch.cuda.synchronize() - ns_stats = compute_histogram_stats(ns_hists) - ns_stats["raw_counts"] = ns_hists.sum(dim=0).tolist() - mname = MAPPING_MODE_NAMES.get(mode, f"m{mode}") - mformula = MAPPING_MODE_FORMULAS.get(mode, mname) - pname = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 13: "alpha"}[mode] - ns_stats["name"] = f"{mname} noscale" - ns_stats["formula"] = mformula - ns_stats["param"] = f"{pname}={power}" - display_name = f"{mname} noscale ({pname}={power})" - key = f"mode_{mode}_{mname}_noscale" - histograms_results[key] = ns_stats - hist_entries.append((display_name, f"m{mode:2d} ns", ns_stats)) - - # Scale Only baseline: mode 3 with p=1.0 (identity + linear scaling) - scale_hists = torch.zeros(hist_eff_bs, 256, dtype=torch.int32, device="cuda") - topk_profile_histogram( - hist_inputs["x"], - hist_inputs["dense_kv_indptr"], - scale_hists, - hist_eff_bs, - args.reserved_bos, - args.reserved_eos, - 3, # mode 3 (power) - 1.0, # p=1.0 → identity transform - None, - None, - ) - torch.cuda.synchronize() - scale_stats = compute_histogram_stats(scale_hists) - scale_stats["raw_counts"] = scale_hists.sum(dim=0).tolist() - scale_stats["name"] = "Scale Only" - scale_stats["formula"] = "Identity + linear scaling to [0,255]" - scale_stats["param"] = "p=1.0" - histograms_results["mode_scale_Scale Only"] = scale_stats - hist_entries.append(("Scale Only (p=1.0)", "scale ", scale_stats)) - - # Print all histogram entries sorted by gini (ascending = more uniform = better) - hist_entries.sort(key=lambda e: e[2]['gini']) - print(f" --- histogram by gini (sorted, lower=better) ---") - for rank, (display_name, mode_tag, stats) in enumerate(hist_entries, 1): - print( - f" {rank:2d}. {display_name:<32s} ({mode_tag}): " - f"gini={stats['gini']:.3f} " - f"max/mean={stats['max_mean_ratio']:.2f} " - f"nonzero_bins={stats['num_nonzero_bins']}/256 " - f"eff_bins={stats['effective_bins']:.1f} " - f"entropy={stats['entropy']:.2f}" - ) - - config_results["histograms"] = histograms_results - - all_results.append(config_results) - - return all_results - - -def main(): - parser = argparse.ArgumentParser(description="TopK kernel benchmark suite") - parser.add_argument("--batch-sizes", nargs="+", type=int, default=[1, 4, 8, 16, 32, 64]) - parser.add_argument("--seq-lens", nargs="+", type=int, default=[1024, 2048, 4096, 8192]) - parser.add_argument("--topk-vals", nargs="+", type=int, default=[16, 30, 64]) - parser.add_argument("--num-kv-heads", nargs="+", type=int, default=[2, 4, 8]) - parser.add_argument("--page-size", type=int, default=16) - parser.add_argument("--reserved-bos", type=int, default=1) - parser.add_argument("--reserved-eos", type=int, default=2) - parser.add_argument("--score-dtype", choices=["bfloat16", "float32"], default="bfloat16") - parser.add_argument("--distributions", nargs="+", default=["normal", "lognormal", "uniform"]) - parser.add_argument("--warmup", type=int, default=10) - parser.add_argument("--repeat", type=int, default=100) - parser.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, - dest="mapping_hparam", - help="Global fallback hyperparameter for parametric modes (default: 0.5)") - parser.add_argument("--mapping-hparam-3", "--mapping-power-3", type=float, default=None, - dest="mapping_hparam_3", - help="Power exponent p for mode 3 (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-6", "--mapping-power-6", type=float, default=None, - dest="mapping_hparam_6", - help="Beta for mode 6 asinh (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-7", "--mapping-power-7", type=float, default=None, - dest="mapping_hparam_7", - help="Alpha for mode 7 log1p (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-13", "--mapping-power-13", type=float, default=None, - dest="mapping_hparam_13", - help="Alpha for mode 13 exp_stretch (overrides --mapping-hparam)") - parser.add_argument("--mapping-hparam-14", "--mapping-power-14", type=float, default=None, - dest="mapping_hparam_14", - help="Rho for mode 14 topk_window (overrides --mapping-hparam)") - parser.add_argument("--autotune-json", type=str, default=None, - help="Path to autotune_results.json — extracts best per-mode hyperparameters " - "(overrides --mapping-power for modes 3/6/7/13/14)") - parser.add_argument("--lut-path", type=str, default=None, help="Path to .npy uint8[256] LUT for mode=1") - parser.add_argument("--quantiles-path", type=str, default=None, help="Path to .npy float32[256] for mode=2") - parser.add_argument("--output-json", type=str, default=None, help="Save results to JSON file") - parser.add_argument("--filter-kernels", nargs="+", default=None, - help="Only run specific kernels: naive, sglang_m0, sglang_m3, sglang_m4") - parser.add_argument("--histogram", action="store_true", help="Collect and report bin distribution statistics") - parser.add_argument("--histogram-pages", type=int, default=None, - help="Total pages for histogram profiling. Default: adaptive " - "(512 samples/bin × 256 bins, rounded to segment boundary). " - "Only used when --histogram is set.") - parser.add_argument("--real-histograms", type=str, default=None, - help="Path to .npy raw_histograms from calibration (adds 'real' distribution)") - parser.add_argument("--counters", action="store_true", - help="Collect diagnostic counters (threshold_bin, num_above, num_equal, " - "remaining_k, refine_rounds, stage2_input) for each sglang kernel") - parser.add_argument("--sample-stride", type=int, default=1, - help="Pre-pass sampling stride for mapped modes (1=full, 4=1/4, 8=1/8). " - "Higher values reduce pre-pass overhead at cost of bin quality (default: 1)") - parser.add_argument("--radix-bits", type=int, default=8, - help="Stage 1 radix bits for ori/mode-0 kernel: 4=16 bins, 6=64, 8=256, 9=512, 10=1024 (default: 8). " - "Range: 4-10. Fewer bits = coarser Stage 1 but faster histogram; more bits = finer but slower.") - - args = parser.parse_args() - results = run_benchmark(args) + results.append({ + "batch_size": bs, "num_kv_heads": heads, "seq_len": seq_len, + "topk_val": topk_val, "distribution": dist, "modes": row_modes, + }) if args.output_json: with open(args.output_json, "w") as f: @@ -945,5 +461,66 @@ def main(): print(f"\nResults saved to {args.output_json}") +def main(): + p = argparse.ArgumentParser("TopK kernel benchmarks") + p.add_argument("--batch-sizes", type=int, nargs="+", default=[4]) + p.add_argument("--num-kv-heads", type=int, nargs="+", default=[8]) + p.add_argument("--seq-lens", type=int, nargs="+", default=[8192]) + p.add_argument("--topk-vals", type=int, nargs="+", default=[30]) + p.add_argument("--distributions", type=str, nargs="+", + default=["normal"], + choices=["normal", "lognormal", "uniform", "bucket_uniform"]) + p.add_argument("--mapping-modes", type=int, nargs="+", + default=[0, 3, 6, 7], + help="Mapping modes to sweep (0=None, 3=Power, 6=Asinh, 7=Log1p, etc.)") + p.add_argument("--mapping-hparam", "--mapping-power", type=float, default=0.5, + dest="mapping_hparam", + help="Fallback hyperparameter for every non-zero mapping mode when " + "no --autotune-json is provided: p for mode 3 (power), beta for " + "mode 6 (asinh), alpha for modes 7/9/10/13 (log1p/erf/tanh/exp_stretch).") + p.add_argument("--autotune-json", type=str, default=None, + help="Path to autotune_results.json produced by autotune_topk_mapping.py. " + "When set, the per-mode hyperparameter with the lowest measured " + "latency in that file is used instead of --mapping-hparam.") + p.add_argument("--lut-path", type=str, default=None, + help="Path to .npy uint8[256] LUT for MAPPING_LUT_CDF (mode 1).") + p.add_argument("--quantiles-path", type=str, default=None, + help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") + p.add_argument("--page-size", type=int, default=16) + p.add_argument("--reserved-bos", type=int, default=1) + p.add_argument("--reserved-eos", type=int, default=2) + p.add_argument("--warmup", type=int, default=10) + p.add_argument("--repeat", type=int, default=100) + p.add_argument("--output-json", type=str, default=None) + p.add_argument("--remap-bench", action="store_true", + help="Run the split-phase remap/topk/fused/baseline benchmark.") + args = p.parse_args() + + args._autotune_hparams = {} + if args.autotune_json: + args._autotune_hparams = _load_autotune_hparams(args.autotune_json) + print(f"[autotune] using best-latency hyperparameters from {args.autotune_json}:") + for m, v in sorted(args._autotune_hparams.items()): + print(f" mode {m:>2d} -> {v}") + + args._mapping_lut = None + args._mapping_quantiles = None + if args.lut_path: + lut_np = np.load(args.lut_path).astype(np.uint8) + assert lut_np.shape == (256,), f"LUT must be [256], got {lut_np.shape}" + args._mapping_lut = torch.from_numpy(lut_np).cuda() + print(f"[mapping] loaded LUT from {args.lut_path}") + if args.quantiles_path: + q_np = np.load(args.quantiles_path).astype(np.float32) + assert q_np.shape == (256,), f"quantiles must be [256], got {q_np.shape}" + args._mapping_quantiles = torch.from_numpy(q_np).cuda() + print(f"[mapping] loaded quantiles from {args.quantiles_path}") + + if args.remap_bench: + _run_remap_bench(args) + else: + _run_latency_sweep(args) + + if __name__ == "__main__": main() diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py index 4c861161..e3524c1f 100644 --- a/benchmarks/calibrate_topk.py +++ b/benchmarks/calibrate_topk.py @@ -44,6 +44,14 @@ def main(): help="Number of calibration prompts to use (default: 16)") parser.add_argument("--output-dir", type=str, default="calibration_output/") parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") + parser.add_argument( + "--watchdog-timeout", + type=float, + default=None, + metavar="SEC", + help="SGLang scheduler watchdog (seconds). Forward batches must complete within this time. " + "Default: engine default (300). Use 0 to disable when using this repo's SGLang fork.", + ) args = parser.parse_args() # Lazy imports to avoid slow startup when just checking --help @@ -54,7 +62,7 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) print(f"[calibrate] Launching engine with hit-rate profiling enabled...") - llm = sgl.Engine( + engine_kwargs = dict( model_path=args.model_name, disable_cuda_graph=True, page_size=args.page_size, @@ -73,6 +81,9 @@ def main(): vortex_topk_mapping_mode=0, # Use mode 0 during calibration vortex_topk_histogram=True, # Enable histogram collection ) + if args.watchdog_timeout is not None: + engine_kwargs["watchdog_timeout"] = args.watchdog_timeout + llm = sgl.Engine(**engine_kwargs) # Clear any residual histograms in the worker process llm.clear_topk_histograms() diff --git a/benchmarks/greedy_layer_search.py b/benchmarks/greedy_layer_search.py deleted file mode 100644 index 118ac454..00000000 --- a/benchmarks/greedy_layer_search.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Greedy forward-selection of layers whose indexer can be skipped (index cache). - -Usage (from repo root): - cd examples && python ../benchmarks/greedy_layer_search.py \ - --model-name Qwen/Qwen3-1.7B --topk-val 30 --threshold 0.95 \ - --trials 1 --num-layers 28 --mem 0.7 - -The script prints progress to stderr and outputs the final selected layer list -(as a Python list literal) on the **last line of stdout** so callers can parse it. -""" - -import argparse -import os -import sys - -# Add examples/ to path so we can import verify_algos -_examples_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "examples") -sys.path.insert(0, _examples_dir) - -from verify_algo import verify_algos # noqa: E402 - - -def _evaluate(shared_layers, args): - """Run verify_algos with the given shared layers and return pass@trials accuracy.""" - summary = verify_algos( - trials=args.trials, - topk_val=args.topk_val, - page_size=args.page_size, - vortex_module_name=args.vortex_module_name, - model_name=args.model_name, - sparse_attention=True, - mem=args.mem, - kv_cache_dtype=args.kv_cache_dtype, - topk_type=args.topk_type, - topk_mapping_mode=0, - topk_mapping_power=args.topk_mapping_power, - index_cache_shared_layers=sorted(shared_layers) if shared_layers else None, - disable_cuda_graph=True, - ) - acc_key = f"pass@{args.trials}" - return summary[acc_key] - - -def greedy_search(args): - # Ensure we're in examples/ so amc23.jsonl relative path works - os.chdir(_examples_dir) - - candidates = list(range(1, args.num_layers)) - - # Baseline: no shared layers - print("Evaluating baseline (no shared layers)...", file=sys.stderr) - baseline_acc = _evaluate([], args) - print(f"Baseline accuracy: {baseline_acc:.4f}", file=sys.stderr) - - threshold = args.threshold - shared_set = [] - - while candidates: - best_layer = None - best_acc = -1.0 - - for layer in candidates: - trial_set = shared_set + [layer] - print(f" Trying shared_set={sorted(trial_set)} ...", file=sys.stderr, end=" ") - acc = _evaluate(trial_set, args) - print(f"acc={acc:.4f}", file=sys.stderr) - - if acc > best_acc: - best_acc = acc - best_layer = layer - - if best_acc >= threshold * baseline_acc: - shared_set.append(best_layer) - candidates.remove(best_layer) - print( - f"Added layer {best_layer} (acc={best_acc:.4f} >= " - f"{threshold * baseline_acc:.4f}). Current set: {sorted(shared_set)}", - file=sys.stderr, - ) - else: - print( - f"Stopping: best candidate layer {best_layer} acc={best_acc:.4f} < " - f"{threshold * baseline_acc:.4f}", - file=sys.stderr, - ) - break - - result = sorted(shared_set) - print(f"Final shared layers: {result}", file=sys.stderr) - # Last stdout line: parseable Python list - print(result) - return result - - -def parse_args(): - parser = argparse.ArgumentParser( - description="Greedy forward-selection of index-cache shared layers." - ) - parser.add_argument("--model-name", type=str, default="Qwen/Qwen3-1.7B") - parser.add_argument("--topk-val", type=int, default=30) - parser.add_argument("--page-size", type=int, default=16) - parser.add_argument("--vortex-module-name", type=str, default="block_sparse_attention") - parser.add_argument("--mem", type=float, default=0.8) - parser.add_argument("--kv-cache-dtype", type=str, default="auto") - parser.add_argument("--topk-type", type=str, default="naive") - parser.add_argument("--topk-mapping-power", type=float, default=0.5) - parser.add_argument("--threshold", type=float, default=0.95, - help="Minimum accuracy ratio vs baseline to keep adding layers (default: 0.95).") - parser.add_argument("--trials", type=int, default=1) - parser.add_argument("--num-layers", type=int, default=28, - help="Total number of model layers (default: 28 for Qwen3-1.7B).") - return parser.parse_args() - - -if __name__ == "__main__": - args = parse_args() - greedy_search(args) diff --git a/csrc/archived/README.md b/csrc/archived/README.md new file mode 100644 index 00000000..6e08a1dc --- /dev/null +++ b/csrc/archived/README.md @@ -0,0 +1,19 @@ +# Archived TopK kernels + +These files are **not compiled** (not listed in `setup.py`) and are kept only +for historical reference. + +- `topk_slgang_ori.cu` — the original SGLang TopK reference kernel (typo in + the filename is intentional, matches the upstream commit it was adapted + from). Superseded by the fused `fast_topk_vortex` path in + `../topk_sglang.cu`. +- `topk_sglang_ori_fastpath.cu` — the `fast_topk_ori` / + `TopKOutput_Ori_Kernel` / `launch_ori_kernel` code extracted out of + `../topk_sglang.cu`. It was the "zero mapping overhead" fast path with + flexible `radix_bits` (4–10). We no longer test it — mode 0 now goes + through the standard fused kernel with `MAPPING_NONE`, which pays no + mapping overhead because `mapped_convert_to_uint8` degenerates to + `convert_to_uint8` in that branch. + +If you need to resurrect any of this, add the `.cu` to `setup.py` and +re-export its entry points from `../register.cc` / `../register.h`. diff --git a/csrc/archived/fast_topk_vortex_prepass.cu b/csrc/archived/fast_topk_vortex_prepass.cu new file mode 100644 index 00000000..5b743f19 --- /dev/null +++ b/csrc/archived/fast_topk_vortex_prepass.cu @@ -0,0 +1,525 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// fast_topk_vortex — the heavy fused remap+topk kernel with auto-range, +// pivot, tail-window, topk-window pre-passes and LUT/quantile support. +// Extracted from csrc/topk_sglang.cu as part of the remap-benchmark refactor. +// Replaced by a lean fast_topk_clean_fused that applies a simple element-wise +// transform (from topk_mapping.cuh apply_transform) in Stage-1 bucketing — +// no pre-pass, no LUT, no auto-range. +// +// References types/constants from its former translation unit (TopKMappingParams, +// needs_*, mapped_convert_to_uint8, kSmem, kThreadsPerBlock, COUNTER_*). This +// file will not compile standalone; kept for history only. + +// ====================================================================== +// Templated version of fast_topk_cuda_tl with mapping support: +// - ScoreT: float or __nv_bfloat16 +// - StopAfterStage1: return after Stage 1 route/filter (for profiling) +// - WriteCounters: write diagnostic counters to global memory + +// - mapping: configurable value-remapping for Stage 1 bin assignment +template +__device__ void fast_topk_vortex( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k, + const TopKMappingParams& mapping, + int* counters = nullptr) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int vh_counter; + alignas(128) __shared__ int vh_threshold_bin_id; + alignas(128) __shared__ int vh_num_input[2]; + + // Shared memory for mapping LUT / quantiles (loaded once per block) + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; + + // Auto-range for transform modes (3/4/6/7) + __shared__ float s_range_min, s_range_inv_range; + + auto& vh_histogram = vh_histogram_buf[0]; + extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Load mapping tables into shared memory if needed + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } + + // Pre-pass: compute per-block min/max of transformed values for linear bucketing. + // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; + // the approximated range may miss extreme outliers but Stage 2 uses raw + // float bits for exact ordering, so correctness is preserved. + if (needs_auto_range(mapping.mode) && !mapping.noscale) { + const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; + float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; + for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { + float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); + local_min = fminf(local_min, val); + local_max = fmaxf(local_max, val); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + // Cross-warp reduction via shared memory + __shared__ float s_warp_mins[32], s_warp_maxs[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + } + if (tx == 0) { + s_range_min = local_min; + float range = local_max - local_min; + s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else if (needs_pivot(mapping.mode)) { + // Pivot pre-pass: compute mean of all elements, store in s_range_min. + // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering + // around the mean helps distribute values more evenly across bins. + float local_sum = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + local_sum += vortex_to_float(input[idx + row_start]); + } + // Warp-level reduction + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + __shared__ float s_warp_sums[32]; + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_sums[warp_id] = local_sum; + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_sum = s_warp_sums[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + } + if (tx == 0) { + s_range_min = local_sum / float(length); // mean as pivot + s_range_inv_range = 0.0f; + } + } + __syncthreads(); + } else if (needs_tail_window(mapping.mode)) { + // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) + // and local_max via a sampled quantile estimator. All 256 coarse bins + // are then allocated to [tau_low, local_max]; scores below tau_low + // collapse into bin 0 via linear_map_to_uint8 clamping. + constexpr int MAX_SAMPLES = 1024; + __shared__ float s_samples[MAX_SAMPLES]; + __shared__ int s_sample_count; + + if (tx == 0) s_sample_count = 0; + __syncthreads(); + + // Compute sampling stride so we collect ~MAX_SAMPLES from the segment + const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; + const int sample_stride = max(desired_stride, 1); + + // Each thread samples elements and finds local_max simultaneously + float local_max = -__FLT_MAX__; + for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + int slot = ::atomicAdd(&s_sample_count, 1); + if (slot < MAX_SAMPLES) { + s_samples[slot] = val; + } + } + + // Reduce local_max across block + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + __shared__ float s_warp_maxs_tw[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw[tx]; + for (int offset = 16; offset > 0; offset >>= 1) + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + if (tx == 0) s_warp_maxs_tw[0] = local_max; + } + __syncthreads(); + local_max = s_warp_maxs_tw[0]; + + int nsamp = min(s_sample_count, MAX_SAMPLES); + + // Simple odd-even transposition sort on the sample buffer. + // nsamp <= 1024, and we have 1024 threads, so each thread + // handles one element. O(nsamp) parallel rounds suffice. + __syncthreads(); + if (nsamp >= 2) { + for (int pass = 0; pass < nsamp; ++pass) { + // Even phase: compare (0,1), (2,3), ... + if (tx * 2 + 1 < nsamp) { + int i = tx * 2; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + // Odd phase: compare (1,2), (3,4), ... + if (tx * 2 + 2 < nsamp) { + int i = tx * 2 + 1; + if (s_samples[i] > s_samples[i + 1]) { + float tmp = s_samples[i]; + s_samples[i] = s_samples[i + 1]; + s_samples[i + 1] = tmp; + } + } + __syncthreads(); + } + } + + // Estimate tau_low = Q(1 - rho * k / n) + if (tx == 0) { + float rho = mapping.power_exp; // reused as tail expansion factor + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float frac = 1.0f - rho * float(k) / float(length); + frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + + float tau_low; + if (nsamp < 4 || frac <= 0.0f) { + // Too few samples or the tail covers everything: full range + tau_low = -__FLT_MAX__; + } else { + float fidx = frac * float(nsamp - 1); + int lo = __float2int_rd(fidx); + lo = min(max(lo, 0), nsamp - 2); + float t = fidx - float(lo); + tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; + } + + // Fallback: if tau_low >= local_max, use full-range linear mapping + if (tau_low >= local_max) { + // Find the actual minimum from sorted samples + tau_low = (nsamp > 0) ? s_samples[0] : local_max; + } + + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + __syncthreads(); + } else if (needs_topk_window(mapping.mode)) { + // Topk-window pre-pass with streaming variance heuristic. + // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) + float local_max = -__FLT_MAX__; + float local_sum = 0.0f, local_sum_sq = 0.0f; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + float val = vortex_to_float(input[idx + row_start]); + local_max = fmaxf(local_max, val); + local_sum += val; + local_sum_sq += val * val; + } + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; + { + int warp_id = tx >> 5, lane_id = tx & 31; + if (lane_id == 0) { + s_warp_maxs_tw2[warp_id] = local_max; + s_warp_sums_tw2[warp_id] = local_sum; + s_warp_sq_tw2[warp_id] = local_sum_sq; + } + } + __syncthreads(); + if (tx < (BLOCK_SIZE >> 5)) { + local_max = s_warp_maxs_tw2[tx]; + local_sum = s_warp_sums_tw2[tx]; + local_sum_sq = s_warp_sq_tw2[tx]; + for (int offset = 16; offset > 0; offset >>= 1) { + local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); + local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); + local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); + } + if (tx == 0) { + float rho = mapping.power_exp; + if (rho <= 0.0f) rho = 4.0f; + int k = (mapping.target_k > 0) ? mapping.target_k : target_k; + float n = float(length); + float mean = local_sum / n; + float var = local_sum_sq / n - mean * mean; + float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; + float ratio = n / fmaxf(float(k), 1.0f); + float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); + float tau_low = local_max - rho * sigma * z; + if (tau_low >= local_max) tau_low = local_max - 1.0f; + float range = local_max - tau_low; + s_range_min = tau_low; + s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; + } + } + __syncthreads(); + } else { + if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } + __syncthreads(); + } + + // Stage 1: 8-bit coarse histogram (with optional mapping) + // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) + // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. + // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. + constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries + uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); + const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); + + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range); + ::atomicAdd(&vh_histogram[bin], 1); + if (use_bin_cache) { + bin_cache[idx] = bin; + } + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = vh_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += vh_histogram_buf[k][tx + j]; + } + vh_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[0] = 0; + vh_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin; + counters[COUNTER_REMAINING_K] = topk; + } + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8( + vortex_to_float(input[idx + row_start]), + mapping, s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = 0; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + int bin; + if (use_bin_cache) { + bin = static_cast(bin_cache[idx]); + } else { + bin = static_cast( + mapped_convert_to_uint8(raw_input, mapping, + s_mapping_lut, s_mapping_quantiles, + s_range_min, s_range_inv_range)); + } + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&vh_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + if (WriteCounters && tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = vh_counter; + counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; + counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; + } + if (StopAfterStage1) return; + } + + // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) + if constexpr (WriteCounters) { + // Default: all 4 rounds used; overwritten at break if resolved early + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + } +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int vh_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = vh_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) + ? _raw_num_input + : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { + vh_threshold_bin_id = tx; + vh_num_input[r_idx ^ 1] = 0; + vh_last_remain = topk - vh_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = vh_threshold_bin_id; + topk -= vh_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32( + vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if constexpr (WriteCounters) { + if (tx == 0 && counters) { + counters[COUNTER_REFINE_ROUNDS] = round + 1; + } + } + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = vh_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&vh_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&vh_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + vh_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&vh_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Wrapper kernel: one CUDA block per batch*head segment +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); + + // Remap position indices -> page indices via dense_kv_indices + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + + diff --git a/csrc/archived/topk_mapping_full.cuh b/csrc/archived/topk_mapping_full.cuh new file mode 100644 index 00000000..f85204ec --- /dev/null +++ b/csrc/archived/topk_mapping_full.cuh @@ -0,0 +1,217 @@ +// Archived: not included by any compiled TU. See csrc/archived/README.md. +// The full mapping header supporting LUT_CDF, QUANTILE, TRUNC8, SUBTRACT, +// ADAPTIVE_TAIL_WINDOW, TOPK_WINDOW and the auto-range/pivot/tail-window +// pre-pass infrastructure. Replaced by the lean element-wise-only header +// at csrc/topk_mapping.cuh for the remap-benchmark refactor. +#pragma once +#include +#include +#include + +// ============================================================ +// TopK bucket-sort Stage-1 remapping strategies +// +// These transforms remap float scores before Stage 1's 8-bit +// histogram binning. The primary goal is to maximize coarse-bin +// resolution in the score region that determines the top-k +// cutoff, thereby: +// - shrinking the Stage-1 threshold bin (fewer collisions) +// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT +// - reducing the number of Stage-2 refine rounds +// +// Stage 2 refinement still uses convert_to_uint32() on raw +// floats, so final ordering correctness is always preserved. +// +// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly +// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) +// directly focuses all 256 bins on the competitive upper tail +// estimated from the top-k ratio, collapsing irrelevant +// low-score mass into bin 0. +// ============================================================ + +enum TopKMappingMode { + MAPPING_NONE = 0, // Original convert_to_uint8 behavior + MAPPING_LUT_CDF = 1, // LUT-based CDF equalization + MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping + MAPPING_POWER = 3, // Monotonic power transform + MAPPING_LOG = 4, // Log transform + // Mode 5 reserved (previously INDEX_CACHE, removed) + MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp + MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing + MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile + MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail + MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] +}; + +struct TopKMappingParams { + int mode; // TopKMappingMode + float power_exp; // For MAPPING_POWER (default 0.5) + // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion + // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). + const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr + const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr + bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) + int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) + int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW +}; + +// NOTE: convert_to_uint8() must be defined before including this header. +// It is defined in topk_sglang.cu within the anonymous namespace. + +// ---- Individual transform functions (return float, no bucketing) ---- + +__device__ __forceinline__ float transform_power(float x, float p) { + return copysignf(__powf(fabsf(x), p), x); +} + +__device__ __forceinline__ float transform_log(float x) { + return copysignf(__logf(fabsf(x) + 1.0f), x); +} + +__device__ __forceinline__ float transform_asinh(float x, float beta) { + return asinhf(beta * x); +} + +__device__ __forceinline__ float transform_log1p(float x, float alpha) { + return copysignf(log1pf(alpha * fabsf(x)), x); +} + +__device__ __forceinline__ float transform_erf(float x, float alpha) { + return erff(alpha * x); +} + +__device__ __forceinline__ float transform_tanh(float x, float alpha) { + return tanhf(alpha * x); +} + +__device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { + float z = alpha * x; + z = fminf(z, 80.0f); // prevent float32 overflow (exp(80) ~ 5.5e34) + return expf(z); +} + +// ---- Transform dispatcher (returns float, no bucketing) ---- + +__device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { + switch (params.mode) { + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + default: return x; + } +} + +// ---- Linear bucketing for transform modes ---- + +__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { + int bin = __float2int_rd((val - range_min) * inv_range); + return static_cast(min(max(bin, 0), 255)); +} + +// ---- BF16-aware bucketing (mode 8) ---- +// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the +// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical +// data (the byte is almost entirely exponent). Instead, convert through +// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the +// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but +// explicitly available as a named mode for documentation/benchmarking. + +__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { + return convert_to_uint8(x); // fp16 sign-flip bucketing +} + +// ---- Non-transform mapping functions (unchanged) ---- + +// LUT-based CDF equalization: lut[original_bin] -> equalized_bin +__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { + return s_lut[convert_to_uint8(x)]; +} + +// Quantile mapping: binary search over 256 sorted thresholds +__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { + // Binary search: find largest index i such that x >= s_quantiles[i] + // s_quantiles is sorted ascending, length 256 + int lo = 0, hi = 255; +#pragma unroll 8 + for (int iter = 0; iter < 8; ++iter) { + int mid = (lo + hi + 1) >> 1; + if (x >= s_quantiles[mid]) { + lo = mid; + } else { + hi = mid - 1; + } + } + return static_cast(lo); +} + +// ---- Unified dispatcher ---- +// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. + +__device__ __forceinline__ uint8_t mapped_convert_to_uint8( + float x, + const TopKMappingParams& params, + const uint8_t* __restrict__ s_lut, + const float* __restrict__ s_quantiles, + float range_min, + float inv_range) +{ + switch (params.mode) { + case MAPPING_LUT_CDF: + if (params.lut != nullptr) return map_lut_cdf(x, s_lut); + return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + case MAPPING_QUANTILE: + if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); + return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated + case MAPPING_POWER: + case MAPPING_LOG: + case MAPPING_ASINH: + case MAPPING_LOG1P: + case MAPPING_ERF: + case MAPPING_TANH: + case MAPPING_EXP_STRETCH: { + float val = apply_transform(x, params); + if (params.noscale) return convert_to_uint8(val); + return linear_map_to_uint8(val, range_min, inv_range); + } + case MAPPING_TRUNC8: + return convert_to_uint8_bf16(x); + case MAPPING_SUBTRACT: + return convert_to_uint8(x - range_min); // range_min repurposed as pivot + case MAPPING_ADAPTIVE_TAIL_WINDOW: + case MAPPING_TOPK_WINDOW: + return linear_map_to_uint8(x, range_min, inv_range); + default: // MAPPING_NONE + return convert_to_uint8(x); + } +} + +// Helper: check if a mapping mode needs the auto-range pre-pass +__device__ __forceinline__ bool needs_auto_range(int mode) { + return (mode == MAPPING_POWER || mode == MAPPING_LOG || + mode == MAPPING_ASINH || mode == MAPPING_LOG1P || + mode == MAPPING_ERF || mode == MAPPING_TANH || + mode == MAPPING_EXP_STRETCH); +} + +// Helper: check if a mapping mode needs the pivot pre-pass +__device__ __forceinline__ bool needs_pivot(int mode) { + return (mode == MAPPING_SUBTRACT); +} + +// Helper: check if mode is the adaptive tail-window pre-pass +__device__ __forceinline__ bool needs_tail_window(int mode) { + return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); +} + +// Helper: check if mode is the lightweight topk-window pre-pass +__device__ __forceinline__ bool needs_topk_window(int mode) { + return (mode == MAPPING_TOPK_WINDOW); +} diff --git a/csrc/archived/topk_sglang_ori_fastpath.cu b/csrc/archived/topk_sglang_ori_fastpath.cu new file mode 100644 index 00000000..29970ecd --- /dev/null +++ b/csrc/archived/topk_sglang_ori_fastpath.cu @@ -0,0 +1,319 @@ +// Archived: not compiled. See csrc/archived/README.md +// +// Flexible-radix (RADIX_BITS 4..10) "ori fast path" for TopK. It was the +// zero-mapping-overhead fast path used when mapping_mode == MAPPING_NONE. +// No longer tested — mode 0 now routes through the fused TopKOutput_Kernel +// with mapping.mode == MAPPING_NONE, which pays no extra cost because +// mapped_convert_to_uint8 collapses to convert_to_uint8 in that branch. +// +// The code below was extracted verbatim from csrc/topk_sglang.cu as of the +// fused-kernel refactor. It references helpers (kSmem, convert_to_uint32, +// vortex_to_float, VORTEX_MAX_TOPK, kThreadsPerBlock, setup_kernel_smem_once, +// CHECK_CUDA, topk_mapping.cuh types) from the surrounding translation unit. +// Dropping this file into a build as-is will not compile; it is reference +// only. + +template +__device__ __forceinline__ uint16_t convert_to_uintN(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return key >> (16 - BITS); +} + +// ====================================================================== +// Ori fast path: zero-overhead topk with no mapping infrastructure. +// Template on RADIX_BITS: 4-10 (16 to 1024 bins). +// ====================================================================== +template +__device__ void fast_topk_ori( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 1 << RADIX_BITS; + constexpr auto RADIX_PAD = RADIX / 2; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); + static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // Stage 1: coarse histogram with RADIX bins + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + for (int i = 0; i < RADIX_BITS; ++i) { + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) + const auto run_cumsum_s2 = [&] { + for (int i = 0; i < 8; ++i) { + if (C10_LIKELY(tx < 256)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < 256 - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uintN(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // Stage 2: refine with 8-bit radix passes +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum_s2(); + if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < 257) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// Ori fast-path wrapper: zero mapping overhead, flexible radix +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Ori_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int page_reserved_bos, + const int page_reserved_eos) +{ + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } +} + +// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch +template +void launch_ori_kernel( + const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, + const int* dense_kv_indices, int* sparse_kv_indices, + int topk_val, int reserved_bos, int reserved_eos, + int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) +{ + #define LAUNCH_ORI(BITS) \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Ori_Kernel<<>>( \ + score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ + topk_val, reserved_bos, reserved_eos) + switch (radix_bits) { + case 4: LAUNCH_ORI(4); break; + case 5: LAUNCH_ORI(5); break; + case 6: LAUNCH_ORI(6); break; + case 7: LAUNCH_ORI(7); break; + case 9: LAUNCH_ORI(9); break; + case 10: LAUNCH_ORI(10); break; + default: LAUNCH_ORI(8); break; + } + #undef LAUNCH_ORI +} + +// ====================================================================== +// Explicit ori baseline entry point — always uses the ori fast path +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t radix_bits) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_ori: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, + "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + launch_ori_kernel<__nv_bfloat16>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else if (x.scalar_type() == at::ScalarType::Float) { + launch_ori_kernel( + x.data_ptr(), + dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, + radix_bits, nblks, nthreads, stream); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); +} diff --git a/csrc/topk_slgang_ori.cu b/csrc/archived/topk_slgang_ori.cu similarity index 100% rename from csrc/topk_slgang_ori.cu rename to csrc/archived/topk_slgang_ori.cu diff --git a/csrc/register.cc b/csrc/register.cc index b2d12b9b..cc201c98 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -13,21 +13,24 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages"), - py::arg("mapping_mode") = 0, - py::arg("mapping_power") = 0.5, - py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false, - py::arg("sample_stride") = 1, - py::arg("radix_bits") = 8); - m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("max_num_pages")); + m.def("topk_output_sglang_fused", &topk_output_sglang_fused, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), py::arg("max_num_pages"), - py::arg("radix_bits") = 8); + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); + m.def("topk_remap_only", &topk_remap_only, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("remapped"), + py::arg("eff_batch_size"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("mapping_mode"), + py::arg("mapping_power")); m.def("topk_profile_histogram", &topk_profile_histogram, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("histograms"), py::arg("eff_batch_size"), @@ -35,21 +38,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false, - py::arg("topk_val") = 0, - py::arg("sample_stride") = 1); - m.def("topk_profile_stage1", &topk_profile_stage1, - py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), - py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), - py::arg("eff_batch_size"), py::arg("topk_val"), - py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages"), - py::arg("mapping_mode") = 0, - py::arg("mapping_power") = 0.5, - py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_quantiles") = py::none()); m.def("topk_profile_counters", &topk_profile_counters, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), @@ -60,8 +49,7 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_mode") = 0, py::arg("mapping_power") = 0.5, py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none(), - py::arg("mapping_noscale") = false); + py::arg("mapping_quantiles") = py::none()); m.def("sglang_plan_decode_fa3", &sglang_plan_decode_fa3); m.def("sglang_plan_prefill_fa3", &sglang_plan_prefill_fa3); m.def("Chunkwise_HN2NH_Transpose_FA3", &Chunkwise_HN2NH_Transpose_FA3); diff --git a/csrc/register.h b/csrc/register.h index 784b754b..e86a9638 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -95,17 +95,10 @@ const int64_t eff_batch_size, const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_seq_lengths, -const int64_t mapping_mode = 0, -const double mapping_power = 0.5, -std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false, -const int64_t sample_stride = 1, -const int64_t radix_bits = 8 +const int64_t max_seq_lengths ); -void topk_output_sglang_ori( +void topk_output_sglang_fused( const at::Tensor& x, const at::Tensor& dense_kv_indptr, const at::Tensor& sparse_kv_indptr, @@ -116,41 +109,34 @@ const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, const int64_t max_num_pages, -const int64_t radix_bits = 8 +const int64_t mapping_mode, +const double mapping_power, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt ); -void topk_profile_histogram( +void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -at::Tensor& histograms, +at::Tensor& remapped, const int64_t eff_batch_size, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t mapping_mode = 0, -const double mapping_power = 0.5, -std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false, -const int64_t topk_val = 0, -const int64_t sample_stride = 1 +const int64_t mapping_mode, +const double mapping_power ); -void topk_profile_stage1( +void topk_profile_histogram( const at::Tensor& x, const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, -const at::Tensor& dense_kv_indices, -at::Tensor& sparse_kv_indices, +at::Tensor& histograms, const int64_t eff_batch_size, -const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, -const int64_t max_num_pages, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +std::optional mapping_quantiles = std::nullopt ); void topk_profile_counters( @@ -168,8 +154,7 @@ const int64_t max_num_pages, const int64_t mapping_mode = 0, const double mapping_power = 0.5, std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt, -const bool mapping_noscale = false +std::optional mapping_quantiles = std::nullopt ); void sglang_plan_decode_fa3( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 773cdeb9..447e5397 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -4,60 +4,46 @@ #include // ============================================================ -// TopK bucket-sort Stage-1 remapping strategies +// TopK bucket-sort Stage-1 remap transforms (lean version). // -// These transforms remap float scores before Stage 1's 8-bit -// histogram binning. The primary goal is to maximize coarse-bin -// resolution in the score region that determines the top-k -// cutoff, thereby: -// - shrinking the Stage-1 threshold bin (fewer collisions) -// - reducing COUNTER_NUM_EQUAL / COUNTER_STAGE2_INPUT -// - reducing the number of Stage-2 refine rounds +// These are element-wise transforms applied to scores before +// the Stage-1 8-bit histogram bucketing. The goal is to spread +// a skewed raw distribution more uniformly across the 256 bins +// so the threshold bin shrinks and Stage-2 refinement does less +// work. Stage 2 still uses convert_to_uint32() on the remapped +// value's raw bits for tie-breaking. // -// Stage 2 refinement still uses convert_to_uint32() on raw -// floats, so final ordering correctness is always preserved. -// -// Modes 3/4/6/7/9/10 apply a nonlinear transform then linearly -// map the result to [0,255]. Mode 12 (ADAPTIVE_TAIL_WINDOW) -// directly focuses all 256 bins on the competitive upper tail -// estimated from the top-k ratio, collapsing irrelevant -// low-score mass into bin 0. +// There is no pre-pass, no auto-range, no LUT, no quantile +// table, and no shared-memory state — each transform is a +// pure function of one float. The heavy pre-pass machinery +// (auto-range, pivot, tail-window, topk-window, LUT_CDF, +// QUANTILE, SUBTRACT, TRUNC8) lives in +// csrc/archived/fast_topk_vortex_prepass.cu. // ============================================================ enum TopKMappingMode { - MAPPING_NONE = 0, // Original convert_to_uint8 behavior - MAPPING_LUT_CDF = 1, // LUT-based CDF equalization - MAPPING_QUANTILE = 2, // Piecewise-linear quantile mapping - MAPPING_POWER = 3, // Monotonic power transform - MAPPING_LOG = 4, // Log transform - // Mode 5 reserved (previously INDEX_CACHE, removed) - MAPPING_ASINH = 6, // asinh(beta * x), beta via power_exp - MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|), alpha via power_exp - MAPPING_TRUNC8 = 8, // BF16 upper-8-bit bucketing - MAPPING_ERF = 9, // erf(alpha * x) - MAPPING_TANH = 10, // tanh(alpha * x) - MAPPING_SUBTRACT = 11, // subtract pivot, then fp16 bucketing - MAPPING_ADAPTIVE_TAIL_WINDOW = 12, // focus bins on upper tail via sampled quantile - MAPPING_EXP_STRETCH = 13, // exp(alpha * x), concentrates bin resolution on upper tail - MAPPING_TOPK_WINDOW = 14, // k-aware linear windowing: focus bins on [tau_low, max] + MAPPING_NONE = 0, // identity (no remap) + MAPPING_LUT_CDF = 1, // bin lookup: new_bin = lut[convert_to_uint8(x)] + MAPPING_QUANTILE = 2, // binary search over 256 calibrated quantile thresholds + MAPPING_POWER = 3, // sign(x) * |x|^p + MAPPING_LOG = 4, // sign(x) * log(|x| + 1) + MAPPING_ASINH = 6, // asinh(beta * x) + MAPPING_LOG1P = 7, // sign(x) * log1p(alpha * |x|) + MAPPING_TRUNC8 = 8, // identity bucketing (historical name, alias of MAPPING_NONE) + MAPPING_ERF = 9, // erf(alpha * x) + MAPPING_TANH = 10, // tanh(alpha * x) + MAPPING_SUBTRACT = 11, // x - pivot, with pivot = power_exp (free hyperparameter) + MAPPING_EXP_STRETCH = 13, // exp(alpha * x) }; struct TopKMappingParams { - int mode; // TopKMappingMode - float power_exp; // For MAPPING_POWER (default 0.5) - // For MAPPING_ADAPTIVE_TAIL_WINDOW: tail expansion - // factor rho (default 4.0). tau_low = Q(1 - rho*k/n). - const uint8_t* __restrict__ lut; // [256] byte LUT, or nullptr - const float* __restrict__ quantiles; // [256] float quantile breakpoints, or nullptr - bool noscale; // Skip auto-range linear scaling, use fp16 bucketing on f(x) - int sample_stride; // Pre-pass sampling stride (1=full, 8=1/8, 0=skip) - int target_k; // Top-k value; used by MAPPING_ADAPTIVE_TAIL_WINDOW + int mode; // TopKMappingMode + float power_exp; // Free hyperparameter: p / alpha / beta / pivot depending on mode + const uint8_t* __restrict__ lut; // [256] uint8 LUT, MAPPING_LUT_CDF only + const float* __restrict__ quantiles; // [256] float quantile breakpoints, MAPPING_QUANTILE only }; -// NOTE: convert_to_uint8() must be defined before including this header. -// It is defined in topk_sglang.cu within the anonymous namespace. - -// ---- Individual transform functions (return float, no bucketing) ---- +// ---- Element-wise transforms ---- __device__ __forceinline__ float transform_power(float x, float p) { return copysignf(__powf(fabsf(x), p), x); @@ -89,124 +75,66 @@ __device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { return expf(z); } -// ---- Transform dispatcher (returns float, no bucketing) ---- - +// Pure element-wise dispatcher. Returns the *float value* after the transform. +// For bin-selection modes (LUT_CDF / QUANTILE) this is identity: the mapping +// happens in compute_stage1_bin() below instead of via a float transform, so +// Stage-2 tie-breaking uses the raw score bits for those modes. __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { switch (params.mode) { - case MAPPING_POWER: return transform_power(x, params.power_exp); - case MAPPING_LOG: return transform_log(x); - case MAPPING_ASINH: return transform_asinh(x, params.power_exp); - case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); - case MAPPING_ERF: return transform_erf(x, params.power_exp); - case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_SUBTRACT: return x - params.power_exp; case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); - default: return x; + case MAPPING_LUT_CDF: + case MAPPING_QUANTILE: + case MAPPING_TRUNC8: + default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE } } -// ---- Linear bucketing for transform modes ---- - -__device__ __forceinline__ uint8_t linear_map_to_uint8(float val, float range_min, float inv_range) { - int bin = __float2int_rd((val - range_min) * inv_range); - return static_cast(min(max(bin, 0), 255)); -} - -// ---- BF16-aware bucketing (mode 8) ---- -// BF16 has 8 exponent + 7 mantissa bits. Taking the upper 8 bits of the -// sign-flipped bf16 bit-pattern yields only ~20 distinct bins for typical -// data (the byte is almost entirely exponent). Instead, convert through -// fp16 (5 exp + 10 mantissa) which puts 5 exp + 2 mantissa bits in the -// upper byte, giving ~135+ distinct bins — equivalent to mode 0 but -// explicitly available as a named mode for documentation/benchmarking. - -__device__ __forceinline__ uint8_t convert_to_uint8_bf16(float x) { - return convert_to_uint8(x); // fp16 sign-flip bucketing +// Whether the mapping mode is a direct bin-selection function (LUT_CDF / +// QUANTILE). These modes need per-block shared-memory tables. +__device__ __forceinline__ bool mapping_uses_table(int mode) { + return mode == MAPPING_LUT_CDF || mode == MAPPING_QUANTILE; } -// ---- Non-transform mapping functions (unchanged) ---- - -// LUT-based CDF equalization: lut[original_bin] -> equalized_bin -__device__ __forceinline__ uint8_t map_lut_cdf(float x, const uint8_t* __restrict__ s_lut) { - return s_lut[convert_to_uint8(x)]; -} - -// Quantile mapping: binary search over 256 sorted thresholds -__device__ __forceinline__ uint8_t map_quantile(float x, const float* __restrict__ s_quantiles) { - // Binary search: find largest index i such that x >= s_quantiles[i] - // s_quantiles is sorted ascending, length 256 +// Binary search over a sorted [256] quantile table. Returns the largest +// index i such that x >= quantiles[i], in [0, 255]. +__device__ __forceinline__ uint8_t quantile_bin_lookup( + float x, const float* __restrict__ s_quantiles) +{ int lo = 0, hi = 255; #pragma unroll 8 for (int iter = 0; iter < 8; ++iter) { int mid = (lo + hi + 1) >> 1; - if (x >= s_quantiles[mid]) { - lo = mid; - } else { - hi = mid - 1; - } + if (x >= s_quantiles[mid]) lo = mid; + else hi = mid - 1; } return static_cast(lo); } -// ---- Unified dispatcher ---- -// For modes 3/4/6/7, range_min and inv_range come from a per-block pre-pass. +// Forward decl so compute_stage1_bin can call it. Defined in the enclosing TU. +__device__ __forceinline__ uint8_t convert_to_uint8(float x); -__device__ __forceinline__ uint8_t mapped_convert_to_uint8( - float x, +// Compute the Stage-1 bin for a raw score under any mapping mode. LUT_CDF / +// QUANTILE use the shared-memory tables loaded at the kernel entry; every +// other mode falls back to convert_to_uint8(apply_transform(x)). +__device__ __forceinline__ uint8_t compute_stage1_bin( + float raw, const TopKMappingParams& params, const uint8_t* __restrict__ s_lut, - const float* __restrict__ s_quantiles, - float range_min, - float inv_range) + const float* __restrict__ s_quantiles) { switch (params.mode) { case MAPPING_LUT_CDF: - if (params.lut != nullptr) return map_lut_cdf(x, s_lut); - return convert_to_uint8(x); // fallback to mode 0 when LUT not calibrated + return s_lut[convert_to_uint8(raw)]; case MAPPING_QUANTILE: - if (params.quantiles != nullptr) return map_quantile(x, s_quantiles); - return convert_to_uint8(x); // fallback to mode 0 when quantiles not calibrated - case MAPPING_POWER: - case MAPPING_LOG: - case MAPPING_ASINH: - case MAPPING_LOG1P: - case MAPPING_ERF: - case MAPPING_TANH: - case MAPPING_EXP_STRETCH: { - float val = apply_transform(x, params); - if (params.noscale) return convert_to_uint8(val); - return linear_map_to_uint8(val, range_min, inv_range); - } - case MAPPING_TRUNC8: - return convert_to_uint8_bf16(x); - case MAPPING_SUBTRACT: - return convert_to_uint8(x - range_min); // range_min repurposed as pivot - case MAPPING_ADAPTIVE_TAIL_WINDOW: - case MAPPING_TOPK_WINDOW: - return linear_map_to_uint8(x, range_min, inv_range); - default: // MAPPING_NONE - return convert_to_uint8(x); + return quantile_bin_lookup(raw, s_quantiles); + default: + return convert_to_uint8(apply_transform(raw, params)); } } - -// Helper: check if a mapping mode needs the auto-range pre-pass -__device__ __forceinline__ bool needs_auto_range(int mode) { - return (mode == MAPPING_POWER || mode == MAPPING_LOG || - mode == MAPPING_ASINH || mode == MAPPING_LOG1P || - mode == MAPPING_ERF || mode == MAPPING_TANH || - mode == MAPPING_EXP_STRETCH); -} - -// Helper: check if a mapping mode needs the pivot pre-pass -__device__ __forceinline__ bool needs_pivot(int mode) { - return (mode == MAPPING_SUBTRACT); -} - -// Helper: check if mode is the adaptive tail-window pre-pass -__device__ __forceinline__ bool needs_tail_window(int mode) { - return (mode == MAPPING_ADAPTIVE_TAIL_WINDOW); -} - -// Helper: check if mode is the lightweight topk-window pre-pass -__device__ __forceinline__ bool needs_topk_window(int mode) { - return (mode == MAPPING_TOPK_WINDOW); -} diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 46dcdd79..2466f570 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -1,8 +1,20 @@ /** - * Vortex TopK kernel — mirrors topk_slgang_ori.cu structure with additions: - * - bf16 support, flexible radix, mapping/remap modes - * - CSR paged wrapper kernels for vortex integration - * Profiling kernels are in topk_sglang_profile.cu. + * Vortex TopK kernels. + * + * Three production kernels: + * - fast_topk_clean : unmapped baseline (two-stage radix). + * - fast_topk_clean_fused : remap + topk fused (apply_transform + * applied inline in Stage-1 bucketing). + * - TopKRemapOnly_Kernel : standalone element-wise remap pass + * used by the split-phase benchmark. + * + * Profiling kernels (counter collection, histogram collection) live in + * topk_sglang_profile.cu and MUST NOT be used for latency measurements — + * they intentionally write extra diagnostic state to global memory. + * + * Archived / historical kernels: csrc/archived/ (fast_topk_vortex with + * pre-pass modes, TopKOutput_Ori_Kernel with flexible radix_bits, the + * original SGLang reference kernel). */ #include #include @@ -100,22 +112,6 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) constexpr int VORTEX_MAX_TOPK = 2048; -constexpr int COUNTER_THRESHOLD_BIN = 0; -constexpr int COUNTER_NUM_ABOVE = 1; -constexpr int COUNTER_NUM_EQUAL = 2; -constexpr int COUNTER_REMAINING_K = 3; -constexpr int COUNTER_REFINE_ROUNDS = 4; -constexpr int COUNTER_STAGE2_INPUT = 5; -constexpr int NUM_TOPK_COUNTERS = 6; - -template -__device__ __forceinline__ uint16_t convert_to_uintN(float x) { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); - return key >> (16 - BITS); -} - #include "topk_mapping.cuh" @@ -463,80 +459,120 @@ void setup_kernel_smem_once() { } // ====================================================================== -// Ori fast path: zero-overhead topk with no mapping infrastructure. -// Template on RADIX_BITS: 4-10 (16 to 1024 bins). +// Templated clean baseline: identical algorithm to fast_topk_cuda_tl but +// parameterised on ScoreT (float or __nv_bfloat16) for the GQA / paged +// call paths that operate on bf16 attention scores. No mapping, no +// pre-pass — pure two-stage radix topk on fp16 bit-pattern bins. // ====================================================================== -template -__device__ void fast_topk_ori( +template +__device__ void fast_topk_clean( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k) { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 1 << RADIX_BITS; - constexpr auto RADIX_PAD = RADIX / 2; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - static_assert(RADIX_BITS >= 4 && RADIX_BITS <= 10, "RADIX_BITS must be 4-10"); - static_assert(RADIX <= BLOCK_SIZE, "RADIX must not exceed BLOCK_SIZE"); + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; - alignas(128) __shared__ int s_histogram_buf[2][RADIX + RADIX_PAD]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; + auto& s_histogram = s_histogram_buf[0]; + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; - auto& s_histogram = s_histogram_buf[0]; - extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + // stage 1: 8-bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(vortex_to_float(input[idx + row_start])); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); - const int tx = threadIdx.x; + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; - // Stage 1: coarse histogram with RADIX bins + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(vortex_to_float(input[idx + row_start]))); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); if (tx < RADIX + 1) s_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = convert_to_uintN(vortex_to_float(input[idx + row_start])); - ::atomicAdd(&s_histogram[bin], 1); + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } } __syncthreads(); + } - const auto run_cumsum = [&] { - for (int i = 0; i < RADIX_BITS; ++i) { - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - // Stage 2 cumsum: always 256 sub-bins (8-bit radix on raw float bits) - const auto run_cumsum_s2 = [&] { - for (int i = 0; i < 8; ++i) { - if (C10_LIKELY(tx < 256)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < 256 - j) { - value += s_histogram_buf[k][tx + j]; - } - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; + // stage 2: refine with 8-bit radix passes on raw fp32 bits +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); run_cumsum(); if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; } __syncthreads(); @@ -544,582 +580,249 @@ __device__ void fast_topk_ori( topk -= s_histogram[threshold_bin + 1]; if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = static_cast(convert_to_uintN(vortex_to_float(input[idx + row_start]))); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; } - __syncthreads(); - return; + } + __syncthreads(); + break; } else { - __syncthreads(); - if (tx < 257) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast(convert_to_uintN(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - } - - // Stage 2: refine with 8-bit radix passes -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); - - run_cumsum_s2(); - if (tx < 256 && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = vortex_to_float(input[idx + row_start]); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < 257) s_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); } - __syncthreads(); + } } + } + __syncthreads(); } + } } // ====================================================================== -// Templated version of fast_topk_cuda_tl with mapping support: -// - ScoreT: float or __nv_bfloat16 -// - StopAfterStage1: return after Stage 1 route/filter (for profiling) -// - WriteCounters: write diagnostic counters to global memory - -// - mapping: configurable value-remapping for Stage 1 bin assignment -template -__device__ void fast_topk_vortex( +// Templated fused kernel: apply_transform(score) -> convert_to_uint8 +// is fused into Stage 1. Stage 2 still uses raw bits for tie-breaking +// (on the *remapped* value, not the original score) — this is a +// benchmarking kernel, the remapped Stage-2 ordering is acceptable. +// No pre-pass, no LUT, no shared-memory mapping state. +// ====================================================================== +template +__device__ void fast_topk_clean_fused( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k, - const TopKMappingParams& mapping, - int* counters = nullptr) + const TopKMappingParams mapping) { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; + alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int f_counter; + alignas(128) __shared__ int f_threshold_bin_id; + alignas(128) __shared__ int f_num_input[2]; - // Shared memory for mapping LUT / quantiles (loaded once per block) - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; + // Shared-memory tables for MAPPING_LUT_CDF / MAPPING_QUANTILE. Loaded + // once at kernel entry and read per element in Stage 1. Other modes + // leave them untouched. + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; - // Auto-range for transform modes (3/4/6/7) - __shared__ float s_range_min, s_range_inv_range; + auto& f_histogram = f_histogram_buf[0]; + extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + const int tx = threadIdx.x; - const int tx = threadIdx.x; + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); - // Pre-pass: compute per-block min/max of transformed values for linear bucketing. - // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; - // the approximated range may miss extreme outliers but Stage 2 uses raw - // float bits for exact ordering, so correctness is preserved. - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - // Cross-warp reduction via shared memory - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean of all elements, store in s_range_min. - // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering - // around the mean helps distribute values more evenly across bins. - float local_sum = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(input[idx + row_start]); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(length); // mean as pivot - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) - // and local_max via a sampled quantile estimator. All 256 coarse bins - // are then allocated to [tau_low, local_max]; scores below tau_low - // collapse into bin 0 via linear_map_to_uint8 clamping. - constexpr int MAX_SAMPLES = 1024; - __shared__ float s_samples[MAX_SAMPLES]; - __shared__ int s_sample_count; - - if (tx == 0) s_sample_count = 0; - __syncthreads(); - - // Compute sampling stride so we collect ~MAX_SAMPLES from the segment - const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; - const int sample_stride = max(desired_stride, 1); - - // Each thread samples elements and finds local_max simultaneously - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count, 1); - if (slot < MAX_SAMPLES) { - s_samples[slot] = val; - } - } + // Stage 1: LUT/QUANTILE do a shared-memory lookup, everything else + // applies the element-wise transform then buckets via convert_to_uint8. + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + ::atomicAdd(&f_histogram[bin], 1); + } + __syncthreads(); - // Reduce local_max across block - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_tw[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_tw[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_tw[0]; - - int nsamp = min(s_sample_count, MAX_SAMPLES); - - // Simple odd-even transposition sort on the sample buffer. - // nsamp <= 1024, and we have 1024 threads, so each thread - // handles one element. O(nsamp) parallel rounds suffice. - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - // Even phase: compare (0,1), (2,3), ... - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - // Odd phase: compare (1,2), (3,4), ... - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - } + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = f_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += f_histogram_buf[k][tx + j]; } + f_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; - // Estimate tau_low = Q(1 - rho * k / n) - if (tx == 0) { - float rho = mapping.power_exp; // reused as tail expansion factor - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float frac = 1.0f - rho * float(k) / float(length); - frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 - - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - // Too few samples or the tail covers everything: full range - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; - } + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[0] = 0; + f_counter = 0; + } + __syncthreads(); - // Fallback: if tau_low >= local_max, use full-range linear mapping - if (tau_low >= local_max) { - // Find the actual minimum from sorted samples - tau_low = (nsamp > 0) ? s_samples[0] : local_max; - } + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - __syncthreads(); - } else if (needs_topk_window(mapping.mode)) { - // Topk-window pre-pass with streaming variance heuristic. - // tau_low = max - rho * sigma * sqrt(2 * log(n/k)) - float local_max = -__FLT_MAX__; - float local_sum = 0.0f, local_sum_sq = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - local_sum += val; - local_sum_sq += val * val; - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { - s_warp_maxs_tw2[warp_id] = local_max; - s_warp_sums_tw2[warp_id] = local_sum; - s_warp_sq_tw2[warp_id] = local_sum_sq; - } - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw2[tx]; - local_sum = s_warp_sums_tw2[tx]; - local_sum_sq = s_warp_sq_tw2[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float n = float(length); - float mean = local_sum / n; - float var = local_sum_sq / n - mean * mean; - float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; - float ratio = n / fmaxf(float(k), 1.0f); - float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); - float tau_low = local_max - rho * sigma * z; - if (tau_low >= local_max) tau_low = local_max - 1.0f; - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } } - - // Stage 1: 8-bit coarse histogram (with optional mapping) - // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) - // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. - // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. - constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries - uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); - const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); - - if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&vh_histogram[bin], 1); - if (use_bin_cache) { - bin_cache[idx] = bin; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&f_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); } + } } __syncthreads(); + } - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; + // stage 2: refine on raw bits of the remapped value +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int f_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = f_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[r_idx ^ 1] = 0; + f_last_remain = topk - f_histogram[tx + 1]; } __syncthreads(); - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_THRESHOLD_BIN] = threshold_bin; - counters[COUNTER_REMAINING_K] = topk; - } + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = 0; - counters[COUNTER_REFINE_ROUNDS] = 0; - counters[COUNTER_STAGE2_INPUT] = 0; + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; } - return; + } + __syncthreads(); + break; } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8(raw_input, mapping, - s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; - counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; - } - if (StopAfterStage1) return; - } - - // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) - if constexpr (WriteCounters) { - // Default: all 4 rounds used; overwritten at break if resolved early - if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; - } -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - if constexpr (WriteCounters) { - if (tx == 0 && counters) { - counters[COUNTER_REFINE_ROUNDS] = round + 1; - } + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&f_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; } - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } + } else { + const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); } - __syncthreads(); + } } + } + __syncthreads(); } + } } -// Wrapper kernel: one CUDA block per batch*head segment +// Wrapper kernels: one CUDA block per (batch*head) segment. + template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKOutput_Kernel( +void TopKOutput_Clean_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, const int* __restrict__ sparse_kv_indptr, @@ -1127,37 +830,34 @@ void TopKOutput_Kernel( int* __restrict__ sparse_kv_indices, const int topk_val, const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) + const int page_reserved_eos) { - const int bx = blockIdx.x; + const int bx = blockIdx.x; - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex(score_blk, s_indices, 0, nblk, topk_val, mapping); - __syncthreads(); + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean(score_blk, s_indices, 0, nblk, topk_val); + __syncthreads(); - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } } -// Ori fast-path wrapper: zero mapping overhead, flexible radix -template +template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKOutput_Ori_Kernel( +void TopKOutput_Fused_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, const int* __restrict__ sparse_kv_indptr, @@ -1165,54 +865,60 @@ void TopKOutput_Ori_Kernel( int* __restrict__ sparse_kv_indices, const int topk_val, const int page_reserved_bos, - const int page_reserved_eos) + const int page_reserved_eos, + const TopKMappingParams mapping) { - const int bx = blockIdx.x; + const int bx = blockIdx.x; - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_ori(score_blk, s_indices, 0, nblk, topk_val); - __syncthreads(); + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); + __syncthreads(); - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } } -// Helper: launch TopKOutput_Ori_Kernel with radix_bits dispatch +// Remap-only kernel: applies the element-wise transform to each score +// in the [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos) +// range and writes the result into a float32 output tensor. Used by +// the split-phase benchmark (remap → unmapped topk). template -void launch_ori_kernel( - const ScoreT* score, const int* dense_kv_indptr, const int* sparse_kv_indptr, - const int* dense_kv_indices, int* sparse_kv_indices, - int topk_val, int reserved_bos, int reserved_eos, - int radix_bits, dim3 nblks, dim3 nthreads, cudaStream_t stream) +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKRemapOnly_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + float* __restrict__ remapped, + const int page_reserved_bos, + const int page_reserved_eos, + const TopKMappingParams mapping) { - #define LAUNCH_ORI(BITS) \ - setup_kernel_smem_once, kSmem>(); \ - TopKOutput_Ori_Kernel<<>>( \ - score, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, \ - topk_val, reserved_bos, reserved_eos) - switch (radix_bits) { - case 4: LAUNCH_ORI(4); break; - case 5: LAUNCH_ORI(5); break; - case 6: LAUNCH_ORI(6); break; - case 7: LAUNCH_ORI(7); break; - case 9: LAUNCH_ORI(9); break; - case 10: LAUNCH_ORI(10); break; - default: LAUNCH_ORI(8); break; - } - #undef LAUNCH_ORI + const int bx = blockIdx.x; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= 0) return; + + const ScoreT* __restrict__ score_blk = score + start; + float* __restrict__ remap_blk = remapped + start; + + for (int i = tx; i < nblk; i += kThreadsPerBlock) { + remap_blk[i] = apply_transform(vortex_to_float(score_blk[i]), mapping); + } } } // namespace @@ -1331,9 +1037,68 @@ void fast_topk_transform_ragged_interface( } // ====================================================================== -// Vortex host entry point — same interface as topk_output in topk.cu +// Vortex host entry point — unmapped baseline topk (no remap). +// This is the "original topk kernel" used as the benchmarking baseline. // ====================================================================== void topk_output_sglang( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Clean_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Clean_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos); + } else { + TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output kernel failed: ", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Fused remap + topk host entry. Applies apply_transform(score, mapping) +// inline inside the Stage-1 histogram build — single kernel launch, +// single pass over the score tensor. +// ====================================================================== +void topk_output_sglang_fused( const at::Tensor& x, const at::Tensor& dense_kv_indptr, const at::Tensor& sparse_kv_indptr, @@ -1347,39 +1112,35 @@ void topk_output_sglang( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale, - const int64_t sample_stride, - const int64_t radix_bits) + std::optional mapping_quantiles) { TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output: topk_val (", topk_val, + "topk_output_sglang_fused: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, - "topk_output: radix_bits must be 4-10, got ", radix_bits); - // Build mapping params from optional tensors + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + TopKMappingParams mapping{}; mapping.mode = static_cast(mapping_mode); mapping.power_exp = static_cast(mapping_power); mapping.lut = nullptr; mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = static_cast(sample_stride); - mapping.target_k = static_cast(topk_val); - if (mapping_lut.has_value()) { const auto& lut = mapping_lut.value(); CHECK_CUDA(lut); TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); + "mapping_lut must be a 1D uint8 tensor of size 256"); mapping.lut = lut.data_ptr(); } if (mapping_quantiles.has_value()) { const auto& q = mapping_quantiles.value(); CHECK_CUDA(q); TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); + "mapping_quantiles must be a 1D float32 tensor of size 256"); mapping.quantiles = q.data_ptr(); } @@ -1387,111 +1148,81 @@ void topk_output_sglang( dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - // Fast path for mode 0 (MAPPING_NONE): use ori kernel with zero mapping overhead - if (mapping_mode == MAPPING_NONE) { - if (x.scalar_type() == at::ScalarType::BFloat16) { - launch_ori_kernel<__nv_bfloat16>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); - } else if (x.scalar_type() == at::ScalarType::Float) { - launch_ori_kernel( - x.data_ptr(), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); - } else { - TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); - } + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Fused_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKOutput_Fused_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); } else { - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, "topk_output: unsupported dtype ", x.scalar_type()); - } + TORCH_CHECK(false, "topk_output_sglang_fused: unsupported dtype ", x.scalar_type()); } const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, - "topk_output kernel failed: ", ::cudaGetErrorString(result)); + "topk_output_sglang_fused kernel failed: ", ::cudaGetErrorString(result)); } // ====================================================================== -// Explicit ori baseline entry point — always uses the ori fast path +// Standalone remap kernel. Writes apply_transform(score) into a +// float32 output buffer without running topk. Used by the split-phase +// benchmark (remap → unmapped topk) to measure each phase independently. // ====================================================================== -void topk_output_sglang_ori( +void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, + at::Tensor& remapped, // float32, same numel as x const int64_t eff_batch_size, - const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, - const int64_t max_num_pages, - const int64_t radix_bits) + const int64_t mapping_mode, + const double mapping_power) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output_sglang_ori: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - TORCH_CHECK(radix_bits >= 4 && radix_bits <= 10, - "topk_output_sglang_ori: radix_bits must be 4-10, got ", radix_bits); - CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(sparse_kv_indptr); - CHECK_CUDA(dense_kv_indices); - CHECK_CUDA(sparse_kv_indices); + CHECK_CUDA(remapped); + TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float, + "remapped output must be float32"); + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); if (x.scalar_type() == at::ScalarType::BFloat16) { - launch_ori_kernel<__nv_bfloat16>( + TopKRemapOnly_Kernel<__nv_bfloat16><<>>( reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); + dense_kv_indptr.data_ptr(), + remapped.data_ptr(), + reserved_bos, reserved_eos, mapping); } else if (x.scalar_type() == at::ScalarType::Float) { - launch_ori_kernel( + TopKRemapOnly_Kernel<<>>( x.data_ptr(), - dense_kv_indptr.data_ptr(), sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, - radix_bits, nblks, nthreads, stream); + dense_kv_indptr.data_ptr(), + remapped.data_ptr(), + reserved_bos, reserved_eos, mapping); } else { - TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + TORCH_CHECK(false, "topk_remap_only: unsupported dtype ", x.scalar_type()); } const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, - "topk_output_sglang_ori kernel failed: ", ::cudaGetErrorString(result)); + "topk_remap_only kernel failed: ", ::cudaGetErrorString(result)); } diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu index 6aeac4b8..adba2d03 100644 --- a/csrc/topk_sglang_profile.cu +++ b/csrc/topk_sglang_profile.cu @@ -98,7 +98,13 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) return __bfloat162float(x); } + constexpr int VORTEX_MAX_TOPK = 2048; + +// Diagnostic counters written by the profiling kernel. These kernels are +// NOT used for latency measurements — they intentionally add global-memory +// writes that distort timings. Latency is measured against the clean +// production kernels in topk_sglang.cu. constexpr int COUNTER_THRESHOLD_BIN = 0; constexpr int COUNTER_NUM_ABOVE = 1; constexpr int COUNTER_NUM_EQUAL = 2; @@ -109,1025 +115,403 @@ constexpr int NUM_TOPK_COUNTERS = 6; #include "topk_mapping.cuh" -// - mapping: configurable value-remapping for Stage 1 bin assignment -template -__device__ void fast_topk_vortex( +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); +} + +// ====================================================================== +// Profiling variant of fast_topk_clean_fused that writes diagnostic +// counters at the end of Stage 1 and at each Stage 2 early-exit. +// Shape / semantics identical to the production kernel, with one extra +// global-memory write pass at the end of each stage. Do not use for +// latency measurements. +// ====================================================================== +template +__device__ void fast_topk_profile( const ScoreT* __restrict__ input, int* __restrict__ index, int row_start, int length, int target_k, - const TopKMappingParams& mapping, - int* counters = nullptr) + const TopKMappingParams mapping, + int* __restrict__ counters) // [NUM_TOPK_COUNTERS] { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int vh_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int vh_counter; - alignas(128) __shared__ int vh_threshold_bin_id; - alignas(128) __shared__ int vh_num_input[2]; + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - // Shared memory for mapping LUT / quantiles (loaded once per block) - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; + alignas(128) __shared__ int p_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int p_counter; + alignas(128) __shared__ int p_threshold_bin_id; + alignas(128) __shared__ int p_num_input[2]; - // Auto-range for transform modes (3/4/6/7) - __shared__ float s_range_min, s_range_inv_range; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; - auto& vh_histogram = vh_histogram_buf[0]; - extern __shared__ int vh_input_idx[][SMEM_INPUT_SIZE]; + auto& p_histogram = p_histogram_buf[0]; + extern __shared__ int p_input_idx[][SMEM_INPUT_SIZE]; - const int tx = threadIdx.x; + const int tx = threadIdx.x; - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } - - // Pre-pass: compute per-block min/max of transformed values for linear bucketing. - // sample_stride > 1 reduces pre-pass cost by scanning every Nth element; - // the approximated range may miss extreme outliers but Stage 2 uses raw - // float bits for exact ordering, so correctness is preserved. - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < length; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(input[idx + row_start]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - // Cross-warp reduction via shared memory - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean of all elements, store in s_range_min. - // MAPPING_SUBTRACT uses convert_to_uint8(x - range_min), so centering - // around the mean helps distribute values more evenly across bins. - float local_sum = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(input[idx + row_start]); - } - // Warp-level reduction - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(length); // mean as pivot - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass: estimate tau_low = Q(1 - rho*k/n) - // and local_max via a sampled quantile estimator. All 256 coarse bins - // are then allocated to [tau_low, local_max]; scores below tau_low - // collapse into bin 0 via linear_map_to_uint8 clamping. - constexpr int MAX_SAMPLES = 1024; - __shared__ float s_samples[MAX_SAMPLES]; - __shared__ int s_sample_count; - - if (tx == 0) s_sample_count = 0; - __syncthreads(); - - // Compute sampling stride so we collect ~MAX_SAMPLES from the segment - const int desired_stride = (length + MAX_SAMPLES - 1) / MAX_SAMPLES; - const int sample_stride = max(desired_stride, 1); - - // Each thread samples elements and finds local_max simultaneously - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride; idx < length; idx += BLOCK_SIZE * sample_stride) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count, 1); - if (slot < MAX_SAMPLES) { - s_samples[slot] = val; - } - } + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; + __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } - // Reduce local_max across block - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_tw[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_tw[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_tw[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_tw[0]; + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); - int nsamp = min(s_sample_count, MAX_SAMPLES); + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + ::atomicAdd(&p_histogram[bin], 1); + } + __syncthreads(); - // Simple odd-even transposition sort on the sample buffer. - // nsamp <= 1024, and we have 1024 threads, so each thread - // handles one element. O(nsamp) parallel rounds suffice. - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - // Even phase: compare (0,1), (2,3), ... - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - // Odd phase: compare (1,2), (3,4), ... - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples[i] > s_samples[i + 1]) { - float tmp = s_samples[i]; - s_samples[i] = s_samples[i + 1]; - s_samples[i + 1] = tmp; - } - } - __syncthreads(); - } - } + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = p_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += p_histogram_buf[k][tx + j]; + } + p_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; - // Estimate tau_low = Q(1 - rho * k / n) - if (tx == 0) { - float rho = mapping.power_exp; // reused as tail expansion factor - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float frac = 1.0f - rho * float(k) / float(length); - frac = fmaxf(frac, 0.0f); // clamp: never go below rank 0 + run_cumsum(); + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[0] = 0; + p_counter = 0; + } + __syncthreads(); - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - // Too few samples or the tail covers everything: full range - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples[lo] * (1.0f - t) + s_samples[lo + 1] * t; - } + const int threshold_bin_0 = p_threshold_bin_id; + const int threshold_bin_size = p_histogram[threshold_bin_0]; // pre-reset count + topk -= p_histogram[threshold_bin_0 + 1]; - // Fallback: if tau_low >= local_max, use full-range linear mapping - if (tau_low >= local_max) { - // Find the actual minimum from sorted samples - tau_low = (nsamp > 0) ? s_samples[0] : local_max; - } + if (tx == 0 && counters) { + counters[COUNTER_THRESHOLD_BIN] = threshold_bin_0; + counters[COUNTER_NUM_EQUAL] = threshold_bin_size; + counters[COUNTER_REMAINING_K] = topk; + } - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - __syncthreads(); - } else if (needs_topk_window(mapping.mode)) { - // Topk-window pre-pass with streaming variance heuristic. - float local_max = -__FLT_MAX__; - float local_sum = 0.0f, local_sum_sq = 0.0f; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - float val = vortex_to_float(input[idx + row_start]); - local_max = fmaxf(local_max, val); - local_sum += val; - local_sum_sq += val * val; - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - __shared__ float s_warp_maxs_tw2[32], s_warp_sums_tw2[32], s_warp_sq_tw2[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { - s_warp_maxs_tw2[warp_id] = local_max; - s_warp_sums_tw2[warp_id] = local_sum; - s_warp_sq_tw2[warp_id] = local_sum_sq; - } - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_tw2[tx]; - local_sum = s_warp_sums_tw2[tx]; - local_sum_sq = s_warp_sq_tw2[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - local_sum_sq += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq, offset); - } - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = (mapping.target_k > 0) ? mapping.target_k : target_k; - float n = float(length); - float mean = local_sum / n; - float var = local_sum_sq / n - mean * mean; - float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; - float ratio = n / fmaxf(float(k), 1.0f); - float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); - float tau_low = local_max - rho * sigma * z; - if (tau_low >= local_max) tau_low = local_max - 1.0f; - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } } - - // Stage 1: 8-bit coarse histogram (with optional mapping) - // Bin cache: store computed bins in vh_input_idx[1] (reinterpreted as uint8_t*) - // to avoid recomputing mapped_convert_to_uint8 in the route/filter pass. - // vh_input_idx[1] is unused until Stage 2 double-buffering starts after route. - constexpr int BIN_CACHE_CAPACITY = SMEM_INPUT_SIZE * static_cast(sizeof(int)); // uint8 entries - uint8_t* bin_cache = reinterpret_cast(vh_input_idx[1]); - const bool use_bin_cache = (length <= BIN_CACHE_CAPACITY); - - if (tx < RADIX + 1) vh_histogram[tx] = 0; + __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_REFINE_ROUNDS] = 0; + counters[COUNTER_STAGE2_INPUT] = 0; + } + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; __syncthreads(); for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&vh_histogram[bin], 1); - if (use_bin_cache) { - bin_cache[idx] = bin; - } + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto bin = static_cast( + compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + if (bin > threshold_bin_0) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin_0) { + const auto pos = ::atomicAdd(&p_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[0][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); + } + } } __syncthreads(); + if (tx == 0 && counters) { + counters[COUNTER_NUM_ABOVE] = p_counter; + counters[COUNTER_STAGE2_INPUT] = p_num_input[0]; + } + } - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = vh_histogram_buf[k][tx]; - if (tx < RADIX - j) { - value += vh_histogram_buf[k][tx + j]; - } - vh_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; + // Stage 2 refinement (4 rounds max). Default rounds=4, overwritten on exit. + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int p_last_remain; + const auto r_idx = round % 2; + const auto _raw_num_input = p_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[0] = 0; - vh_counter = 0; + if (tx < RADIX && p_histogram[tx] > topk && p_histogram[tx + 1] <= topk) { + p_threshold_bin_id = tx; + p_num_input[r_idx ^ 1] = 0; + p_last_remain = topk - p_histogram[tx + 1]; } __syncthreads(); - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_THRESHOLD_BIN] = threshold_bin; - counters[COUNTER_REMAINING_K] = topk; - } + const auto threshold_bin = p_threshold_bin_id; + topk -= p_histogram[threshold_bin + 1]; if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8( - vortex_to_float(input[idx + row_start]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = 0; - counters[COUNTER_REFINE_ROUNDS] = 0; - counters[COUNTER_STAGE2_INPUT] = 0; - } - return; + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = round + 1; + break; } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const auto raw_input = vortex_to_float(input[idx + row_start]); - int bin; - if (use_bin_cache) { - bin = static_cast(bin_cache[idx]); - } else { - bin = static_cast( - mapped_convert_to_uint8(raw_input, mapping, - s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range)); - } - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&vh_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[0][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - __syncthreads(); - if (WriteCounters && tx == 0 && counters) { - counters[COUNTER_NUM_ABOVE] = vh_counter; - counters[COUNTER_NUM_EQUAL] = vh_num_input[0]; - counters[COUNTER_STAGE2_INPUT] = vh_num_input[0]; - } - if (StopAfterStage1) return; - } - - // Stage 2: refine with 8-bit radix passes (unchanged — uses raw float bits) - if constexpr (WriteCounters) { - // Default: all 4 rounds used; overwritten at break if resolved early - if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; - } -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - __shared__ int vh_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = vh_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) - ? _raw_num_input - : int(SMEM_INPUT_SIZE); - - run_cumsum(); - if (tx < RADIX && vh_histogram[tx] > topk && vh_histogram[tx + 1] <= topk) { - vh_threshold_bin_id = tx; - vh_num_input[r_idx ^ 1] = 0; - vh_last_remain = topk - vh_histogram[tx + 1]; - } - __syncthreads(); - - const auto threshold_bin = vh_threshold_bin_id; - topk -= vh_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32( - vortex_to_float(input[idx + row_start])) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } + __syncthreads(); + if (tx < RADIX + 1) p_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = p_input_idx[r_idx][i]; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform(raw, mapping); + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&p_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&p_last_remain, -1); + if (pos > 0) { + index[target_k - pos] = idx; } - __syncthreads(); - if constexpr (WriteCounters) { - if (tx == 0 && counters) { - counters[COUNTER_REFINE_ROUNDS] = round + 1; - } + } else { + const auto pos = ::atomicAdd(&p_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + p_input_idx[r_idx ^ 1][pos] = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&p_histogram[sub_bin], 1); } - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) vh_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = vh_input_idx[r_idx][i]; - const auto raw_input = vortex_to_float(input[idx + row_start]); - const auto offset = 24 - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&vh_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == 3) { - const auto pos = ::atomicAdd(&vh_last_remain, -1); - if (pos > 0) { - index[target_k - pos] = idx; - } - } else { - const auto pos = ::atomicAdd(&vh_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - vh_input_idx[r_idx ^ 1][pos] = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&vh_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); + } } + } + __syncthreads(); } + } } -template -void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { -#ifdef USE_ROCM - // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, - // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing - // a function pointer directly, so cast explicitly. - return ::cudaFuncSetAttribute( - reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#else - // CUDA: keep original behavior (no cast needed). - return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#endif - }(); - TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); -} - -// ====================================================================== +// Wrapper: one block per (batch*head) segment. Writes counters per +// segment into a [eff_batch_size, NUM_TOPK_COUNTERS] int32 tensor. template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKStage1_Kernel( +void TopKProfileCounters_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, const int* __restrict__ sparse_kv_indptr, const int* __restrict__ dense_kv_indices, int* __restrict__ sparse_kv_indices, + int* __restrict__ counters, const int topk_val, const int page_reserved_bos, const int page_reserved_eos, const TopKMappingParams mapping) { - const int bx = blockIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; - - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping); - __syncthreads(); - - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } + const int bx = blockIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; + if (nblk <= topk_val) return; + + const ScoreT* __restrict__ score_blk = score + start; + const int* __restrict__ idx_blk = dense_kv_indices + start; + int* __restrict__ out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + + __shared__ int s_indices[VORTEX_MAX_TOPK]; + fast_topk_profile( + score_blk, s_indices, 0, nblk, topk_val, mapping, + counters + bx * NUM_TOPK_COUNTERS); + __syncthreads(); + + const int tx = threadIdx.x; + for (int i = tx; i < topk_val; i += kThreadsPerBlock) { + out_blk[i] = idx_blk[s_indices[i]]; + } } -// ====================================================================== -// Profiling counters kernel: runs full pipeline + writes diagnostic -// counters to a separate global-memory tensor -// ====================================================================== +// Histogram-only profiling kernel: builds a 256-bin histogram of the +// remapped bins for each segment. Purely diagnostic — never timed. template __global__ __launch_bounds__(kThreadsPerBlock) -void TopKCounters_Kernel( +void TopKProfileHistogram_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - int* __restrict__ counters, // [eff_batch_size, NUM_TOPK_COUNTERS] - const int topk_val, + int* __restrict__ histograms, // [eff_batch_size, 256] const int page_reserved_bos, const int page_reserved_eos, const TopKMappingParams mapping) { - const int bx = blockIdx.x; + constexpr auto RADIX = 256; + constexpr auto BLOCK_SIZE = kThreadsPerBlock; + __shared__ int s_histogram[RADIX]; + __shared__ uint8_t s_mapping_lut[256]; + __shared__ float s_mapping_quantiles[256]; - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; - if (nblk <= topk_val) return; + const int bx = blockIdx.x; + const int tx = threadIdx.x; - const ScoreT* __restrict__ score_blk = score + start; - const int* __restrict__ idx_blk = dense_kv_indices + start; - int* __restrict__ out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int nblk = end - start; - __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_vortex( - score_blk, s_indices, 0, nblk, topk_val, mapping, - counters + bx * NUM_TOPK_COUNTERS); + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { + if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; __syncthreads(); + } + if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { + if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; + __syncthreads(); + } - // Remap position indices -> page indices via dense_kv_indices - const int tx = threadIdx.x; - for (int i = tx; i < topk_val; i += kThreadsPerBlock) { - out_blk[i] = idx_blk[s_indices[i]]; - } -} - -// ====================================================================== -// Profiling histogram kernel: runs only Stage 1 and returns per-segment -// 256-bin histograms for distribution analysis -// ====================================================================== -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKHistogram_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - int* __restrict__ histograms, // [eff_batch_size, 256] - const int page_reserved_bos, - const int page_reserved_eos, - const TopKMappingParams mapping) -{ - constexpr auto RADIX = 256; - constexpr auto BLOCK_SIZE = kThreadsPerBlock; - __shared__ int s_histogram[RADIX]; - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; - __shared__ float s_range_min, s_range_inv_range; - - const int bx = blockIdx.x; - const int tx = threadIdx.x; - - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int nblk = end - start; + if (tx < RADIX) s_histogram[tx] = 0; + __syncthreads(); + if (nblk > 0) { const ScoreT* __restrict__ score_blk = score + start; - - // Load mapping tables into shared memory if needed - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); + for (int i = tx; i < nblk; i += BLOCK_SIZE) { + const float raw = vortex_to_float(score_blk[i]); + const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + ::atomicAdd(&s_histogram[bin], 1); } + } + __syncthreads(); - // Pre-pass: compute per-block min/max for transform modes (supports sampled stride) - if (needs_auto_range(mapping.mode) && !mapping.noscale) { - const int stride = (mapping.sample_stride > 1) ? mapping.sample_stride : 1; - float local_min = __FLT_MAX__, local_max = -__FLT_MAX__; - for (int idx = tx * stride; idx < nblk; idx += BLOCK_SIZE * stride) { - float val = apply_transform(vortex_to_float(score_blk[idx]), mapping); - local_min = fminf(local_min, val); - local_max = fmaxf(local_max, val); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - __shared__ float s_warp_mins[32], s_warp_maxs[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { s_warp_mins[warp_id] = local_min; s_warp_maxs[warp_id] = local_max; } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_min = s_warp_mins[tx]; local_max = s_warp_maxs[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_min = fminf(local_min, __shfl_xor_sync(0xFFFFFFFF, local_min, offset)); - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - } - if (tx == 0) { - s_range_min = local_min; - float range = local_max - local_min; - s_range_inv_range = (range > 0.0f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else if (needs_pivot(mapping.mode)) { - // Pivot pre-pass: compute mean for MAPPING_SUBTRACT - float local_sum = 0.0f; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - local_sum += vortex_to_float(score_blk[idx]); - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - __shared__ float s_warp_sums_h[32]; - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_sums_h[warp_id] = local_sum; - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_sum = s_warp_sums_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_sum += __shfl_xor_sync(0xFFFFFFFF, local_sum, offset); - } - if (tx == 0) { - s_range_min = local_sum / float(nblk); - s_range_inv_range = 0.0f; - } - } - __syncthreads(); - } else if (needs_tail_window(mapping.mode)) { - // Adaptive tail-window pre-pass (histogram kernel variant) - constexpr int MAX_SAMPLES_H = 1024; - __shared__ float s_samples_h[MAX_SAMPLES_H]; - __shared__ int s_sample_count_h; - - if (tx == 0) s_sample_count_h = 0; - __syncthreads(); - - const int desired_stride = (nblk + MAX_SAMPLES_H - 1) / MAX_SAMPLES_H; - const int sample_stride_h = max(desired_stride, 1); - - float local_max = -__FLT_MAX__; - for (int idx = tx * sample_stride_h; idx < nblk; idx += BLOCK_SIZE * sample_stride_h) { - float val = vortex_to_float(score_blk[idx]); - local_max = fmaxf(local_max, val); - int slot = ::atomicAdd(&s_sample_count_h, 1); - if (slot < MAX_SAMPLES_H) s_samples_h[slot] = val; - } - - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - __shared__ float s_warp_maxs_h[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) s_warp_maxs_h[warp_id] = local_max; - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max = s_warp_maxs_h[tx]; - for (int offset = 16; offset > 0; offset >>= 1) - local_max = fmaxf(local_max, __shfl_xor_sync(0xFFFFFFFF, local_max, offset)); - if (tx == 0) s_warp_maxs_h[0] = local_max; - } - __syncthreads(); - local_max = s_warp_maxs_h[0]; - - int nsamp = min(s_sample_count_h, MAX_SAMPLES_H); - - __syncthreads(); - if (nsamp >= 2) { - for (int pass = 0; pass < nsamp; ++pass) { - if (tx * 2 + 1 < nsamp) { - int i = tx * 2; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - if (tx * 2 + 2 < nsamp) { - int i = tx * 2 + 1; - if (s_samples_h[i] > s_samples_h[i + 1]) { - float tmp = s_samples_h[i]; - s_samples_h[i] = s_samples_h[i + 1]; - s_samples_h[i + 1] = tmp; - } - } - __syncthreads(); - } - } - - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = mapping.target_k; - float frac = (k > 0 && nblk > 0) ? 1.0f - rho * float(k) / float(nblk) : 0.0f; - frac = fmaxf(frac, 0.0f); - - float tau_low; - if (nsamp < 4 || frac <= 0.0f) { - tau_low = -__FLT_MAX__; - } else { - float fidx = frac * float(nsamp - 1); - int lo = __float2int_rd(fidx); - lo = min(max(lo, 0), nsamp - 2); - float t = fidx - float(lo); - tau_low = s_samples_h[lo] * (1.0f - t) + s_samples_h[lo + 1] * t; - } - - if (tau_low >= local_max) { - tau_low = (nsamp > 0) ? s_samples_h[0] : local_max; - } - - float range = local_max - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - __syncthreads(); - } else if (needs_topk_window(mapping.mode)) { - // Topk-window pre-pass with streaming variance (histogram kernel variant) - float local_max_h = -__FLT_MAX__; - float local_sum_h = 0.0f, local_sum_sq_h = 0.0f; - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - float val = vortex_to_float(score_blk[idx]); - local_max_h = fmaxf(local_max_h, val); - local_sum_h += val; - local_sum_sq_h += val * val; - } - for (int offset = 16; offset > 0; offset >>= 1) { - local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); - local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); - local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); - } - __shared__ float s_warp_maxs_tw3[32], s_warp_sums_tw3[32], s_warp_sq_tw3[32]; - { - int warp_id = tx >> 5, lane_id = tx & 31; - if (lane_id == 0) { - s_warp_maxs_tw3[warp_id] = local_max_h; - s_warp_sums_tw3[warp_id] = local_sum_h; - s_warp_sq_tw3[warp_id] = local_sum_sq_h; - } - } - __syncthreads(); - if (tx < (BLOCK_SIZE >> 5)) { - local_max_h = s_warp_maxs_tw3[tx]; - local_sum_h = s_warp_sums_tw3[tx]; - local_sum_sq_h = s_warp_sq_tw3[tx]; - for (int offset = 16; offset > 0; offset >>= 1) { - local_max_h = fmaxf(local_max_h, __shfl_xor_sync(0xFFFFFFFF, local_max_h, offset)); - local_sum_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_h, offset); - local_sum_sq_h += __shfl_xor_sync(0xFFFFFFFF, local_sum_sq_h, offset); - } - if (tx == 0) { - float rho = mapping.power_exp; - if (rho <= 0.0f) rho = 4.0f; - int k = mapping.target_k; - float n = float(nblk); - float mean = local_sum_h / n; - float var = local_sum_sq_h / n - mean * mean; - float sigma = (var > 0.0f) ? sqrtf(var) : 0.0f; - float ratio = n / fmaxf(float(k), 1.0f); - float z = sqrtf(2.0f * __logf(fmaxf(ratio, 1.0f))); - float tau_low = local_max_h - rho * sigma * z; - if (tau_low >= local_max_h) tau_low = local_max_h - 1.0f; - float range = local_max_h - tau_low; - s_range_min = tau_low; - s_range_inv_range = (range > 1e-10f) ? 255.0f / range : 0.0f; - } - } - __syncthreads(); - } else { - if (tx == 0) { s_range_min = 0.0f; s_range_inv_range = 0.0f; } - __syncthreads(); - } - - // Initialize shared histogram - if (tx < RADIX) s_histogram[tx] = 0; - __syncthreads(); - - // Build histogram over the segment with mapping - for (int idx = tx; idx < nblk; idx += BLOCK_SIZE) { - const auto bin = mapped_convert_to_uint8( - vortex_to_float(score_blk[idx]), - mapping, s_mapping_lut, s_mapping_quantiles, - s_range_min, s_range_inv_range); - ::atomicAdd(&s_histogram[bin], 1); - } - __syncthreads(); - - // Write to global memory - int* __restrict__ out = histograms + bx * RADIX; - if (tx < RADIX) { - out[tx] = s_histogram[tx]; - } + int* __restrict__ out = histograms + bx * RADIX; + if (tx < RADIX) out[tx] = s_histogram[tx]; } } // namespace #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") -// ====================================================================== -// Profiling: collect per-segment 256-bin histograms of Stage 1 bins -// ====================================================================== -void topk_profile_histogram( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - at::Tensor& histograms, - const int64_t eff_batch_size, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale, - const int64_t topk_val, - const int64_t sample_stride) -{ - CHECK_CUDA(x); - CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(histograms); - TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size - && histograms.size(1) == 256, - "histograms must be [eff_batch_size, 256]"); - TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, - "histograms must be int32"); - - // Build mapping params - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = static_cast(sample_stride); - mapping.target_k = static_cast(topk_val); - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - TopKHistogram_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - TopKHistogram_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - histograms.data_ptr(), - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_histogram: unsupported dtype ", - x.scalar_type()); - } - - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); -} - -// Helper: build TopKMappingParams from host arguments static TopKMappingParams build_mapping_params( int64_t mapping_mode, double mapping_power, std::optional& mapping_lut, - std::optional& mapping_quantiles, - bool mapping_noscale = false, - int sample_stride = 1, - int target_k = 0) + std::optional& mapping_quantiles) { - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - mapping.noscale = mapping_noscale; - mapping.sample_stride = sample_stride; - mapping.target_k = target_k; - - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } - return mapping; + TopKMappingParams m{}; + m.mode = static_cast(mapping_mode); + m.power_exp = static_cast(mapping_power); + m.lut = nullptr; + m.quantiles = nullptr; + if (mapping_lut.has_value()) { + const auto& lut = mapping_lut.value(); + TORCH_CHECK(lut.is_cuda(), "mapping_lut must be a CUDA tensor"); + TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, + "mapping_lut must be a 1D uint8 tensor of size 256"); + m.lut = lut.data_ptr(); + } + if (mapping_quantiles.has_value()) { + const auto& q = mapping_quantiles.value(); + TORCH_CHECK(q.is_cuda(), "mapping_quantiles must be a CUDA tensor"); + TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, + "mapping_quantiles must be a 1D float32 tensor of size 256"); + m.quantiles = q.data_ptr(); + } + return m; } // ====================================================================== -// Profiling: Stage 1 only (pre-pass + hist + cumsum + route/filter) +// Profiling: per-segment 256-bin histograms of Stage 1 remapped bins. // ====================================================================== -void topk_profile_stage1( +void topk_profile_histogram( const at::Tensor& x, const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, + at::Tensor& histograms, const int64_t eff_batch_size, - const int64_t topk_val, const int64_t reserved_bos, const int64_t reserved_eos, - const int64_t max_num_pages, const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) + std::optional mapping_quantiles) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_stage1: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKStage1_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_stage1: unsupported dtype ", - x.scalar_type()); - } + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(histograms); + TORCH_CHECK(histograms.dim() == 2 && histograms.size(0) == eff_batch_size + && histograms.size(1) == 256, + "histograms must be [eff_batch_size, 256]"); + TORCH_CHECK(histograms.scalar_type() == at::ScalarType::Int, + "histograms must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + TopKProfileHistogram_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + TopKProfileHistogram_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + histograms.data_ptr(), + reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_histogram: unsupported dtype ", x.scalar_type()); + } - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_stage1 kernel failed: ", ::cudaGetErrorString(result)); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_histogram kernel failed: ", ::cudaGetErrorString(result)); } // ====================================================================== -// Profiling: full pipeline + diagnostic counters +// Profiling: full pipeline + per-segment diagnostic counters. +// Adds extra global-memory writes — never use for latency measurement. // ====================================================================== void topk_profile_counters( const at::Tensor& x, @@ -1144,60 +528,53 @@ void topk_profile_counters( const int64_t mapping_mode, const double mapping_power, std::optional mapping_lut, - std::optional mapping_quantiles, - const bool mapping_noscale) + std::optional mapping_quantiles) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_profile_counters: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - CHECK_CUDA(counters); - TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size - && counters.size(1) == NUM_TOPK_COUNTERS, - "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); - TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, - "counters must be int32"); - - auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles, - mapping_noscale, /*sample_stride=*/1, /*target_k=*/static_cast(topk_val)); - - dim3 nblks(eff_batch_size); - dim3 nthreads(kThreadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKCounters_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - counters.data_ptr(), - topk_val, - reserved_bos, - reserved_eos, - mapping); - } else { - TORCH_CHECK(false, - "topk_profile_counters: unsupported dtype ", - x.scalar_type()); - } + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_profile_counters: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + CHECK_CUDA(counters); + TORCH_CHECK(counters.dim() == 2 && counters.size(0) == eff_batch_size + && counters.size(1) == NUM_TOPK_COUNTERS, + "counters must be [eff_batch_size, ", NUM_TOPK_COUNTERS, "]"); + TORCH_CHECK(counters.scalar_type() == at::ScalarType::Int, "counters must be int32"); + + auto mapping = build_mapping_params(mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + + dim3 nblks(eff_batch_size); + dim3 nthreads(kThreadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + if (x.scalar_type() == at::ScalarType::BFloat16) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<__nv_bfloat16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else if (x.scalar_type() == at::ScalarType::Float) { + setup_kernel_smem_once, kSmem>(); + TopKProfileCounters_Kernel<<>>( + x.data_ptr(), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + counters.data_ptr(), + topk_val, reserved_bos, reserved_eos, mapping); + } else { + TORCH_CHECK(false, "topk_profile_counters: unsupported dtype ", x.scalar_type()); + } - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_profile_counters kernel failed: ", ::cudaGetErrorString(result)); } - diff --git a/examples/remap_function_bench.sh b/examples/remap_function_bench.sh new file mode 100755 index 00000000..7d56d57a --- /dev/null +++ b/examples/remap_function_bench.sh @@ -0,0 +1,238 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. +# +# Usage: +# # Default (Qwen/Qwen3-1.7B, block_size=16): +# bash remap_function_bench.sh --gpu 5 +# +# # Larger model + larger page/block size: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --block-size 32 \ +# --seq-len 16384 --topk-val 512 \ +# --modes "0 3 6 7" +# +# # Reuse an existing calibration: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --real-histograms /path/to/calibration/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=65536 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes 1 (LUT_CDF) and 2 (Quantile) are evaluated only if calibration +# produces lut.npy / quantiles.npy. The shell script detects that below. +MAPPING_MODES="0 1 2 3 6 7 8 9 10 11 13" +# Fallback hparam used only if autotune is explicitly skipped. +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Empty by default — Step 1 will calibrate on the selected model. +# Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# Calibration may have produced lut.npy / quantiles.npy for modes 1 and 2. +CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" +[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" +[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index fcc2ff1f..36d4cd4b 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -20,6 +20,10 @@ # bash run_distribution_analysis.sh --gpu 5 # bash run_distribution_analysis.sh --gpu 5 \ # --real-histograms /path/to/calibration_dir/raw_histograms.npy +# bash run_distribution_analysis.sh --gpu 5 --block-size 16 +# bash run_distribution_analysis.sh --watchdog-timeout 0 # disable calibrate watchdog (fork) +# Models (default: 1.7B + 4B). Override with repeated --model-name: +# bash run_distribution_analysis.sh --model-name Qwen/Qwen3-1.7B --model-name Qwen/Qwen3-4B # ============================================================ # Mapping functions: @@ -42,21 +46,34 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=4 -MODEL_NAME="Qwen/Qwen3-1.7B" +GPU_ID=7 +# Models to run (full pipeline per model). Override with one or more --model-name. +MODEL_NAMES=( "Qwen/Qwen3-1.7B" "Qwen/Qwen3-4B" ) +MODEL_NAMES_USER_SET=0 TOPK_VAL=30 MEM=0.7 ALGO="block_sparse_attention" RADIX_BITS=8 SAMPLE_STRIDE=1 SEQ_LEN=32768 +# KV page / block size (passed to benchmarks as --page-size) +BLOCK_SIZE=16 # The path to the raw_histograms.npy file (set to skip calibration) REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" REAL_HISTOGRAMS="" +HAS_WATCHDOG_TIMEOUT=0 +WATCHDOG_TIMEOUT="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in - --model-name) MODEL_NAME="$2"; shift 2 ;; + --model-name) + if [ "${MODEL_NAMES_USER_SET}" -eq 0 ]; then + MODEL_NAMES=() + MODEL_NAMES_USER_SET=1 + fi + MODEL_NAMES+=("$2") + shift 2 + ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; @@ -65,14 +82,21 @@ while [[ $# -gt 0 ]]; do --radix-bits) RADIX_BITS="$2"; shift 2 ;; --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size) BLOCK_SIZE="$2"; shift 2 ;; + --watchdog-timeout) HAS_WATCHDOG_TIMEOUT=1; WATCHDOG_TIMEOUT="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" -# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) -MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +if [ "${#MODEL_NAMES[@]}" -eq 0 ]; then + echo "ERROR: No models in MODEL_NAMES; pass at least one --model-name." + exit 1 +fi + +# Validate seq_len: need pages/seg > topk_val (reserved=3 pages + slack) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" @@ -82,140 +106,126 @@ fi RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/dist_analysis_${TIMESTAMP}" -mkdir -p "${RUN_DIR}" echo "============================================================" echo "Bucket Distribution Profiling Pipeline" -echo " Model: ${MODEL_NAME}" +echo " Models (${#MODEL_NAMES[@]}): ${MODEL_NAMES[*]}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Block size: ${BLOCK_SIZE} (--page-size in benchmarks)" echo " GPU: ${GPU_ID}" echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" echo " Sample stride: ${SAMPLE_STRIDE}" -echo " Real histograms: ${REAL_HISTOGRAMS:-}" -echo " Output: ${RUN_DIR}" -echo "============================================================" - -# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── -if [ -n "${REAL_HISTOGRAMS}" ]; then - echo "" - echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" - REAL_HIST_PATH="${REAL_HISTOGRAMS}" -else - echo "" - echo ">>> Step 1: Calibrating — collecting real-inference histograms" - CALIBRATION_DIR="${RUN_DIR}/calibration" - mkdir -p "${CALIBRATION_DIR}" - python "${BENCH_DIR}/calibrate_topk.py" \ - --model-name "${MODEL_NAME}" \ - --topk-val "${TOPK_VAL}" \ - --mem "${MEM}" \ - --vortex-module-name "${ALGO}" \ - --output-dir "${CALIBRATION_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" - echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" -fi - -# ── Step 2: Auto-tune — sweep hyperparameters ────────────────── -echo "" -echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7, 9, 10)" - -AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" - -# Build autotune data source args -AUTOTUNE_EXTRA_ARGS=() -if [ -n "${REAL_HIST_PATH:-}" ]; then - AUTOTUNE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") -fi - -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val "${TOPK_VAL}" \ - --batch-size 4 \ - --seq-len ${SEQ_LEN} \ - --num-kv-heads 2 \ - "${AUTOTUNE_EXTRA_ARGS[@]}" \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step2_autotune.log" - -echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" - -# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── -echo "" -echo ">>> Step 3: Kernel-level histogram profiling (bucket_uniform + normal)" - -BENCH_JSON="${RUN_DIR}/bench_distribution.json" - -# Build optional args for bench_topk.py -BENCH_EXTRA_ARGS=() -if [ -n "${REAL_HIST_PATH:-}" ]; then - BENCH_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") -fi - -# Derive calibration directory from histogram path to find lut.npy / quantiles.npy -CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" -LUT_FILE="${CALIB_DIR}/lut.npy" -QUANTILES_FILE="${CALIB_DIR}/quantiles.npy" - -if [ -f "${LUT_FILE}" ]; then - BENCH_EXTRA_ARGS+=(--lut-path "${LUT_FILE}") - echo " Using LUT for mode 1: ${LUT_FILE}" -else - echo " WARNING: ${LUT_FILE} not found — mode 1 (LUT CDF) will be skipped" -fi -if [ -f "${QUANTILES_FILE}" ]; then - BENCH_EXTRA_ARGS+=(--quantiles-path "${QUANTILES_FILE}") - echo " Using quantiles for mode 2: ${QUANTILES_FILE}" +if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + echo " Watchdog (cal): ${WATCHDOG_TIMEOUT}s (0 = off, vortex SGLang fork)" else - echo " WARNING: ${QUANTILES_FILE} not found — mode 2 (Quantile) will be skipped" + echo " Watchdog (cal): " fi +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Run id: ${TIMESTAMP}" +echo " Output root: ${RESULTS_DIR}/dist_analysis__${TIMESTAMP}/" +echo "============================================================" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens ${SEQ_LEN} \ - --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 8 \ - --distributions bucket_uniform normal \ - --histogram \ - --counters \ - "${BENCH_EXTRA_ARGS[@]}" \ - --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m1 sglang_m2 sglang_m3 sglang_m3_noscale sglang_m4 sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ - --radix-bits "${RADIX_BITS}" \ - --sample-stride "${SAMPLE_STRIDE}" \ - --repeat 20 \ - --output-json "${BENCH_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step3_bench.log" - -echo ">>> Step 3: Done. Results saved to ${BENCH_JSON}" - -# ── Step 4: Analyze — comparison plots + tables ─────────────── -echo "" -echo ">>> Step 4: Generating distribution comparison plots + tables" +for MODEL_NAME in "${MODEL_NAMES[@]}"; do + MODEL_SLUG="${MODEL_NAME//\//_}" + RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_${TIMESTAMP}" + mkdir -p "${RUN_DIR}" -# Build optional args for analyze -ANALYZE_EXTRA_ARGS=() -if [ -n "${REAL_HIST_PATH:-}" ]; then - ANALYZE_EXTRA_ARGS+=(--real-histograms "${REAL_HIST_PATH}") -fi + echo "" + echo "############################ MODEL: ${MODEL_NAME} ############################" + echo " Output: ${RUN_DIR}" + + # ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── + if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" + else + echo "" + echo ">>> Step 1: Calibrating — collecting real-inference histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + CALIB_EXTRA_ARGS=() + if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then + CALIB_EXTRA_ARGS+=(--watchdog-timeout "${WATCHDOG_TIMEOUT}") + fi + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem "${MEM}" \ + --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + "${CALIB_EXTRA_ARGS[@]}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + fi + + # Pick up lut.npy / quantiles.npy if calibration produced them. + CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" + LUT_PATH="" + Q_PATH="" + [ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" + [ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + [ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" + [ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + + # ── Step 2: Auto-tune — rank by fused-topk kernel latency ────── + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" -python "${BENCH_DIR}/analyze_topk_distribution.py" \ - --bench-json "${BENCH_JSON}" \ - "${ANALYZE_EXTRA_ARGS[@]}" \ - --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step4_analyze.log" + AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" + AUTOTUNE_EXTRA=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") -echo ">>> Step 4: Done." + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size 4 \ + --seq-len ${SEQ_LEN} \ + --page-size "${BLOCK_SIZE}" \ + --num-kv-heads 2 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + + # ── Step 3: Remap benchmark with autotuned hparams ────────────── + echo "" + echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" + + BENCH_JSON="${RUN_DIR}/remap_bench.json" + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes 4 \ + --num-kv-heads 8 \ + --seq-lens ${SEQ_LEN} \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions bucket_uniform normal \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ + --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --repeat 20 \ + --output-json "${BENCH_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_bench.log" + + echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}" +done # ── Summary ─────────────────────────────────────────────────── echo "" echo "============================================================" echo "Bucket Distribution Profiling Complete" -echo " All outputs in: ${RUN_DIR}/" -echo " autotune_results.json — hyperparameter sweep rankings" -echo " bench_distribution.json — raw benchmark data" -echo " distribution_comparison.png — bucket dist plots" -echo " bucket_counts.csv — per-bucket count table" -echo " step{1,2,3,4}_*.log — pipeline logs" +echo " Per-model outputs under ${RESULTS_DIR}/ (run id ${TIMESTAMP}):" +echo " dist_analysis__${TIMESTAMP}/" +echo " autotune_results.json, bench_distribution.json, plots, CSV, logs" echo "============================================================" diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index 65e4f413..ec726656 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -1,27 +1,31 @@ #!/usr/bin/env bash # ============================================================ -# Bucket Distribution Profiling Pipeline (modes 3, 6, 7 only) +# Bucket Distribution / Remap Latency Pipeline (parametric modes) # -# Tests only the parametric mapping modes with auto-tuning: -# Mode 3 (Power): y = sign(x) * |x|^p -# Mode 6 (Asinh): y = asinh(beta * x) -# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) -# Mode 8 (Trunc8): bf16 upper-8-bit bucketing -# Mode 9 (Erf): y = erf(alpha * x) -# Mode 10 (Tanh): y = tanh(alpha * x) -# Mode 11 (Subtract): x - pivot (RadiK-style scatter) +# Tests the surviving parametric mapping modes after the lean +# refactor: +# Mode 3 (Power): y = sign(x) * |x|^p +# Mode 6 (Asinh): y = asinh(beta * x) +# Mode 7 (Log1p): y = sign(x) * log1p(alpha * |x|) +# Mode 9 (Erf): y = erf(alpha * x) +# Mode 10 (Tanh): y = tanh(alpha * x) +# Mode 13 (ExpStretch): y = exp(alpha * x) # -# Four steps: -# 1. Calibrate — collect real-data histograms -# (skippable via --real-histograms PATH) -# 2. Auto-tune — sweep hyperparameters on synthetic data -# 3. Bench — histogram profiling (bucket_uniform + normal) -# 4. Analyze — comparison plots + bucket count tables +# Pipeline: +# 1. Calibrate — collect real-distribution histograms from the +# chosen model (skippable via --real-histograms). +# 2. Autotune — rank per-mode hparams by measured fused-topk +# kernel latency (lowest wins). +# 3. Remap bench— bench_topk.py --remap-bench fed with the +# autotune JSON. Reports per-mode remap / topk / +# fused / baseline latencies and threshold stats. # # Usage: # bash run_distribution_analysis_new.sh --gpu 5 # bash run_distribution_analysis_new.sh --gpu 5 \ -# --real-histograms /path/to/calibration_dir/raw_histograms.npy +# --model-name Qwen/Qwen3-8B --block-size 32 +# bash run_distribution_analysis_new.sh --gpu 5 \ +# --real-histograms /path/to/raw_histograms.npy # ============================================================ set -euo pipefail @@ -34,65 +38,78 @@ MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 ALGO="block_sparse_attention" -RADIX_BITS=8 -SAMPLE_STRIDE=1 SEQ_LEN=65536 -# The path to the raw_histograms.npy file (set to skip calibration) -REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" -# REAL_HISTOGRAMS="" -# ── Parse arguments ─────────────────────────────────────────── +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="bucket_uniform normal" +# LUT_CDF (1) / QUANTILE (2) are evaluated only when calibration produces +# lut.npy / quantiles.npy. 0 baseline is always included by --remap-bench. +MAPPING_MODES="1 2 3 6 7 8 9 10 11 13" +REPEAT=100 +WARMUP=20 +REAL_HISTOGRAMS="" + while [[ $# -gt 0 ]]; do case "$1" in - --model-name) MODEL_NAME="$2"; shift 2 ;; - --topk-val) TOPK_VAL="$2"; shift 2 ;; - --mem) MEM="$2"; shift 2 ;; - --gpu) GPU_ID="$2"; shift 2 ;; - --algo) ALGO="$2"; shift 2 ;; - --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; - --radix-bits) RADIX_BITS="$2"; shift 2 ;; - --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; - --seq-len) SEQ_LEN="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" -# Validate seq_len: need pages/seg > topk_val (page_size=16, reserved=3 pages) -MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * 16 )) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then - echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL}." - echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN}" exit 1 fi RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -RUN_DIR="${RESULTS_DIR}/dist_analysis_topk${TOPK_VAL}_${TIMESTAMP}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/dist_analysis_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" -echo "Bucket Distribution Profiling (modes 3, 6, 7, 8, 9, 10, 11)" +echo "Bucket Distribution / Remap Latency Pipeline (parametric modes)" echo " Model: ${MODEL_NAME}" echo " Algorithm: ${ALGO}" echo " TopK: ${TOPK_VAL}" -echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / 16 )) pages/seg)" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" echo " GPU: ${GPU_ID}" -echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" -echo " Sample stride: ${SAMPLE_STRIDE}" -echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" echo "============================================================" -# ── Step 1: Calibrate — collect real-data histograms + LUT/quantiles ── +# ── Step 1: Calibrate ─────────────────────────────────────────── if [ -n "${REAL_HISTOGRAMS}" ]; then echo "" echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" REAL_HIST_PATH="${REAL_HISTOGRAMS}" else echo "" - echo ">>> Step 1: Calibrating — collecting real-inference histograms" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" CALIBRATION_DIR="${RUN_DIR}/calibration" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ @@ -100,74 +117,75 @@ else --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" fi -# ── Step 2: Auto-tune — sweep hyperparameters on synthetic data ───── -echo "" -echo ">>> Step 2: Auto-tuning hyperparameters (modes 3, 6, 7)" +# Pick up lut.npy / quantiles.npy if calibration produced them. +CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" +# ── Step 2: Autotune (latency-ranked) ─────────────────────────── +echo "" +echo ">>> Step 2: Auto-tuning hyperparameters by fused-topk kernel latency" AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" - +AUTOTUNE_EXTRA=() +[ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ - --batch-size 4 \ - --seq-len ${SEQ_LEN} \ - --num-kv-heads 8 \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --page-size "${BLOCK_SIZE}" \ --real-histograms "${REAL_HIST_PATH}" \ - --latency-rerank \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2_autotune.log" - echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" -# ── Step 3: Histogram profiling (bucket_uniform + normal) ───── +# ── Step 3: Remap bench with autotuned hparams ────────────────── echo "" -echo ">>> Step 3: Kernel-level histogram profiling (modes 3, 6, 7)" - -BENCH_JSON="${RUN_DIR}/bench_distribution.json" - +echo ">>> Step 3: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" +BENCH_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") +[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens ${SEQ_LEN} \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 8 \ - --distributions bucket_uniform normal \ - --histogram \ - --counters \ - --real-histograms "${REAL_HIST_PATH}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels naive sglang_ori sglang_m0 sglang_scale sglang_m3 sglang_m3_noscale sglang_m6 sglang_m6_noscale sglang_m7 sglang_m7_noscale sglang_m8 sglang_m9 sglang_m9_noscale sglang_m10 sglang_m10_noscale sglang_m11 sglang_m13 sglang_m13_noscale sglang_m14 \ - --radix-bits "${RADIX_BITS}" \ - --sample-stride "${SAMPLE_STRIDE}" \ - --repeat 20 \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step3_bench.log" +echo ">>> Step 3: Done. Remap bench saved to ${BENCH_JSON}" -echo ">>> Step 3: Done. Results saved to ${BENCH_JSON}" - -# ── Step 4: Analyze — comparison plots + tables ─────────────── -echo "" -echo ">>> Step 4: Generating distribution comparison plots + tables" - -python "${BENCH_DIR}/analyze_topk_distribution.py" \ - --bench-json "${BENCH_JSON}" \ - --real-histograms "${REAL_HIST_PATH}" \ - --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step4_analyze.log" -echo ">>> Step 4: Done." - -# ── Summary ─────────────────────────────────────────────────── +# ── Summary ───────────────────────────────────────────────────── echo "" echo "============================================================" -echo "Bucket Distribution Profiling Complete (modes 3, 6, 7, 8, 9, 10, 11)" +echo "Bucket Distribution / Remap Latency Pipeline Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" echo " All outputs in: ${RUN_DIR}/" -echo " autotune_results.json — hyperparameter sweep rankings" -echo " bench_distribution.json — raw benchmark data" -echo " distribution_comparison.png — bucket dist plots" -echo " bucket_counts.csv — per-bucket count table" -echo " step{1,2,3,4}_*.log — pipeline logs" +echo " calibration/raw_histograms.npy — real topk distribution" +echo " autotune_results.json — latency-ranked hparams" +echo " remap_bench.json — remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" echo "============================================================" diff --git a/examples/test_topk.py b/examples/test_topk.py new file mode 100644 index 00000000..01edc7b4 --- /dev/null +++ b/examples/test_topk.py @@ -0,0 +1,118 @@ +import torch +import triton +# topk_output_sglang expects sparse_kv_indptr before dense_kv_indices (unlike topk_output). +from vortex_torch_C import topk_output_sglang as topk_output + +SEQ_LENS = [4096] +BATCH_SIZES = [256] + +K = 32 +RESERVE_BOS = 0 +RESERVE_EOS = 0 +DEVICE = "cuda" + + +def make_inputs(batch_size, seq_len, k, reserve_bos, reserve_eos, device="cuda"): + dense_kv_indptr = torch.arange( + 0, batch_size * seq_len + 1, seq_len, dtype=torch.int32, device=device + ) + + dense_kv_indices = torch.arange( + 0, batch_size * seq_len, dtype=torch.int32, device=device + ) + + scores = torch.randn( + batch_size * seq_len, dtype=torch.bfloat16, device=device + ) + + # ✅ Fixed CSR-style sparse indptr + sparse_kv_indptr = torch.arange( + 0, batch_size * k + 1, k, dtype=torch.int32, device=device + ) + + sparse_kv_indices = torch.empty( + batch_size * k, dtype=torch.int32, device=device + ) + + return ( + scores, + dense_kv_indptr, + dense_kv_indices, + sparse_kv_indptr, + sparse_kv_indices, + ) + + +def bench_one(batch_size, seq_len, k, reserve_bos, reserve_eos): + ( + scores, + dense_kv_indptr, + dense_kv_indices, + sparse_kv_indptr, + sparse_kv_indices, + ) = make_inputs( + batch_size=batch_size, + seq_len=seq_len, + k=k, + reserve_bos=reserve_bos, + reserve_eos=reserve_eos, + device=DEVICE, + ) + + def fn(): + topk_output( + scores, + dense_kv_indptr, + sparse_kv_indptr, + dense_kv_indices, + sparse_kv_indices, + batch_size, + k, + reserve_bos, + reserve_eos, + seq_len, + ) + + # warmup + for _ in range(10): + fn() + torch.cuda.synchronize() + + ms = triton.testing.do_bench( + fn, + warmup=100, + rep=1000, + return_mode="mean", + ) + return ms + + +def main(): + torch.cuda.init() + + results = {} + + for bs in BATCH_SIZES: + results[bs] = {} + for seq_len in SEQ_LENS: + ms = bench_one( + batch_size=bs, + seq_len=seq_len, + k=K, + reserve_bos=RESERVE_BOS, + reserve_eos=RESERVE_EOS, + ) + results[bs][seq_len] = ms + print(f"bs={bs:>3}, seq_len={seq_len:>4} -> {ms:.6f} ms") + + print("\nLatency table (ms):") + header = "bs\\seq".ljust(10) + "".join(f"{s:>12}" for s in SEQ_LENS) + print(header) + + for bs in BATCH_SIZES: + row = f"{bs:<10}" + "".join(f"{results[bs][s]:>12.4f}" for s in SEQ_LENS) + print(row) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/verify_algo.py b/examples/verify_algo.py index a1d1b6f3..a78f1e69 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -120,8 +120,6 @@ def verify_algos( topk_type: str = "naive", topk_mapping_mode: int = 0, topk_mapping_hparam: float = 0.5, -topk_mapping_lut_path: str = None, -topk_mapping_quantiles_path: str = None, disable_cuda_graph: bool = False, benchmark: str = "amc23", ): @@ -143,8 +141,6 @@ def verify_algos( vortex_topk_type=topk_type, vortex_topk_mapping_mode=topk_mapping_mode, vortex_topk_mapping_hparam=topk_mapping_hparam, - vortex_topk_mapping_lut_path=topk_mapping_lut_path, - vortex_topk_mapping_quantiles_path=topk_mapping_quantiles_path, ) tokenizer = AutoTokenizer.from_pretrained(model_name) if benchmark != "amc23" else None prompts, requests = _load_benchmark(benchmark, trials, tokenizer=tokenizer) @@ -300,15 +296,17 @@ def parse_args(): "--topk-type", type=str, default="naive", - choices=["naive", "sglang", "sglang_ori"], - help='TopK kernel type: "naive" for topk_output, "sglang" for topk_output_sglang, "sglang_ori" for original sglang baseline (default: "naive").', + choices=["naive", "sglang", "sglang_fused"], + help='TopK kernel type: "naive" (CUB radix), "sglang" (unmapped baseline), "sglang_fused" (fused remap + topk). Default: "naive".', ) parser.add_argument( "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14], - help='TopK mapping mode: 0=none, 1=lut_cdf, 2=quantile, 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, 9=erf, 10=tanh, 11=subtract, 12=adaptive_tail_window, 13=exp_stretch, 14=topk_window (default: 0).', + choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13], + help='TopK mapping mode for sglang_fused: 0=none, 1=lut_cdf (calibrated), ' + '2=quantile (calibrated), 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, ' + '9=erf, 10=tanh, 11=subtract, 13=exp_stretch (default: 0).', ) parser.add_argument( @@ -319,20 +317,6 @@ def parse_args(): help='Hyperparameter for parametric modes: power exponent (mode 3), beta (mode 6), alpha (mode 7/9/10/13), rho (mode 12/14). Default: 0.5.', ) - parser.add_argument( - "--topk-mapping-lut-path", - type=str, - default=None, - help="Path to .npy file with uint8[256] LUT for topk mapping mode 1.", - ) - - parser.add_argument( - "--topk-mapping-quantiles-path", - type=str, - default=None, - help="Path to .npy file with float32[256] quantiles for topk mapping mode 2.", - ) - parser.add_argument( "--benchmark", type=str, @@ -366,8 +350,6 @@ def parse_args(): topk_type=args.topk_type, topk_mapping_mode=args.topk_mapping_mode, topk_mapping_hparam=args.topk_mapping_hparam, - topk_mapping_lut_path=args.topk_mapping_lut_path, - topk_mapping_quantiles_path=args.topk_mapping_quantiles_path, benchmark=bench_name, ) summary["benchmark"] = bench_name diff --git a/examples/verify_algo.sh b/examples/verify_algo.sh index ddcd905e..7a96d1e7 100644 --- a/examples/verify_algo.sh +++ b/examples/verify_algo.sh @@ -23,9 +23,4 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) --topk-type naive \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" - done - - TORCH_CUDA_ARCH_LIST="12.0" \ - MAX_JOBS=64 \ - pip install -e . --no-build-isolation \ - -Ccmake.args="-DENABLE_BELOW_SM90=OFF" \ No newline at end of file + done \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 9a9f482e..711a0f77 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -1,26 +1,46 @@ #!/usr/bin/env bash +# ============================================================ +# E2E accuracy comparison: naive baseline + unmapped sglang + +# every surviving parametric mapping mode (3, 4, 6, 7, 9, 10, 13) +# with per-mode hyperparameters picked by autotune_topk_mapping.py +# (ranked by measured fused-topk kernel latency, lowest wins). +# +# Surviving mapping modes after the lean refactor: +# 0: None — unmapped baseline +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ set -e -# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use -# Mapping functions: -# 0: None — original fp16 bit-pattern bucketing -# 1: LUT CDF — LUT-based CDF equalization (calibrated) -# 2: Quantile — piecewise-linear quantile mapping (calibrated) -# 3: Power — y = sign(x) * |x|^p -# 4: Log — y = sign(x) * log(|x| + 1) -# 5: Index Cache — reuse previous layer's indices -# 6: Asinh — y = asinh(beta * x) -# 7: Log1p — y = sign(x) * log1p(alpha * |x|) -# 8: Trunc8 — bf16 upper-8-bit bucketing -# 9: Erf — y = erf(alpha * x) -# 10: Tanh — y = tanh(alpha * x) -# 11: Subtract — x - pivot (RadiK-style scatter) + +# ── Defaults ────────────────────────────────────────────────── GPU_ID=0 BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=30 +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 while [[ $# -gt 0 ]]; do case "$1" in - --gpu) GPU_ID="$2"; shift 2 ;; - --benchmark) BENCHMARKS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done @@ -30,273 +50,151 @@ export CUDA_VISIBLE_DEVICES="${GPU_ID}" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" -sparse_algos=( - "block_sparse_attention" -) +sparse_algos=( "block_sparse_attention" ) BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') -RESULTS_DIR="results/${BENCH_LABEL}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/${MODEL_SLUG}_${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -# Set this to an existing calibration directory to skip re-running calibration. -# It must contain lut.npy and quantiles.npy (output of calibrate_topk.py). -CALIBRATION_DIR="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration" + # ============================================================ -# Baseline: naive topk (mode 0) +# Baseline: naive topk # ============================================================ for algo in "${sparse_algos[@]}"; do OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_naive_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type naive --topk-mapping-mode 0" - echo ">>> Saving results to ${OUTFILE}" + echo ">>> naive topk algo=${algo}" { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val "${TOPK_VAL}" \ --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ + --model-name "${MODEL_NAME}" \ --topk-type naive \ - --topk-mapping-mode 0 \ --benchmark ${BENCHMARKS} \ --mem 0.7 ; } \ 2>&1 | tee "${OUTFILE}" done # ============================================================ -# Calibration: collect histograms for LUT/quantile generation -# Skipped if CALIBRATION_DIR already has lut.npy + quantiles.npy +# Calibrate (optional) — real-distribution histograms # ============================================================ -if [ -f "${CALIBRATION_DIR}/lut.npy" ] && [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then - echo ">>> Calibration SKIPPED (using existing ${CALIBRATION_DIR})" -else - CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" - for algo in "${sparse_algos[@]}"; do - echo ">>> Calibrating for ${algo}..." - python "${BENCH_DIR}/calibrate_topk.py" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-val 30 \ - --mem 0.7 \ - --vortex-module-name "${algo}" \ - --output-dir "${CALIBRATION_DIR}" \ - 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" - done +if [ -z "${REAL_HISTOGRAMS}" ]; then + CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + for algo in "${sparse_algos[@]}"; do + echo ">>> Calibrating ${MODEL_NAME} for ${algo}..." + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --vortex-module-name "${algo}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibration_${algo}_${TIMESTAMP}.log" + done + REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" fi -# ============================================================ -# Auto-tune: find best hyperparameters per mode -# Uses topk_profile_histogram kernel on real calibration data -# ============================================================ -REAL_HISTOGRAMS="${CALIBRATION_DIR}/raw_histograms.npy" -if [ -f "${REAL_HISTOGRAMS}" ]; then - echo "============================================================" - echo "Auto-tuning hyperparameters (real calibration data)" - echo "============================================================" - AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" - PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val 30 \ - --batch-size 4 \ - --seq-len 32768 \ - --num-kv-heads 2 \ - --real-histograms "${REAL_HISTOGRAMS}" \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" - echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" - echo "" +# Pick up lut.npy / quantiles.npy if calibration produced them. +CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" + +# ============================================================ +# Auto-tune — rank by fused-topk kernel latency +# ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + if [ -f "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Auto-tuning hyperparameters (real distribution, latency-ranked)" + echo "============================================================" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" + else + echo ">>> WARNING: ${REAL_HISTOGRAMS} not found — autotune will use synthetic data" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val "${TOPK_VAL}" \ + --batch-size "${BATCH_SIZE}" \ + --seq-len "${SEQ_LEN}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + fi +fi - # Extract best per-mode hyperparameters from autotune JSON - eval "$(python3 -c " +# Extract best per-mode hparam (ranked by kernel latency, lowest wins). +eval "$(python3 -c " import json, sys data = json.load(open(sys.argv[1])) best = {} for r in data: - m = r.get('mode') - if m in (3, 6, 7, 9, 10): - if m not in best or r['gini'] < best[m]['gini']: - best[m] = r -for m in (3, 6, 7, 9, 10): - print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') + m = r.get('mode'); lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}') " "${AUTOTUNE_JSON}")" - echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10}" - echo "" -else - echo ">>> WARNING: ${REAL_HISTOGRAMS} not found, using default power=0.5 for all modes" - BEST_HPARAM_3=0.5 - BEST_HPARAM_6=0.5 - BEST_HPARAM_7=0.5 - BEST_HPARAM_9=0.5 - BEST_HPARAM_10=0.5 -fi - -# ============================================================ -# Mode 1: LUT CDF with calibrated LUT -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_1_calibrated_${TIMESTAMP}.log" - echo ">>> Running mode 1 (LUT CDF) with calibrated LUT for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 1 \ - --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 2: Quantile with calibrated quantiles -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_2_calibrated_${TIMESTAMP}.log" - echo ">>> Running mode 2 (quantile) with calibrated quantiles for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 2 \ - --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# sglang topk: non-parametric modes (0, 4, 8, 11) -# ============================================================ -for algo in "${sparse_algos[@]}"; do - for topk_mapping_mode in 0 4 8 11; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_${topk_mapping_mode}_${TIMESTAMP}.log" - echo ">>> Running verify_algo.py with --vortex-module-name ${algo} --topk-type sglang --topk-mapping-mode ${topk_mapping_mode}" - echo ">>> Saving results to ${OUTFILE}" +echo ">>> Autotuned hparams (lowest fused-topk latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" +echo "" +run_mapped() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra=() + if [ "${mode}" -eq 0 ]; then + extra+=(--topk-type sglang) + else + extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi { time python verify_algo.py \ --trials 8 \ - --topk-val 30 \ + --topk-val "${TOPK_VAL}" \ --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode ${topk_mapping_mode} \ + --model-name "${MODEL_NAME}" \ --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" + --mem 0.7 \ + "${extra[@]}" ; } \ + 2>&1 | tee "${out}" done -done +} + +run_mapped 0 0.5 "sglang_m0" +run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" +run_mapped 4 0.5 "sglang_m4" +run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" +run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" +run_mapped 8 0.5 "sglang_m8" +run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" +run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" +run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" +run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" -# ============================================================ -# Mode 3: power — autotuned best p -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" - echo ">>> Running mode 3 (power) p=${BEST_HPARAM_3} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-hparam ${BEST_HPARAM_3} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 6: asinh — autotuned best beta -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" - echo ">>> Running mode 6 (asinh) beta=${BEST_HPARAM_6} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-hparam ${BEST_HPARAM_6} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 7: log1p — autotuned best alpha -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" - echo ">>> Running mode 7 (log1p) alpha=${BEST_HPARAM_7} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-hparam ${BEST_HPARAM_7} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 9: erf — autotuned best alpha -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" - echo ">>> Running mode 9 (erf) alpha=${BEST_HPARAM_9} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 9 \ - --topk-mapping-hparam ${BEST_HPARAM_9} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Mode 10: tanh — autotuned best alpha -# ============================================================ -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" - echo ">>> Running mode 10 (tanh) alpha=${BEST_HPARAM_10} (autotuned) for ${algo}" - echo ">>> Saving results to ${OUTFILE}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val 30 \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 10 \ - --topk-mapping-hparam ${BEST_HPARAM_10} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Counter profiling: collect COUNTER_NUM_EQUAL for all modes -# ============================================================ echo "" echo "============================================================" -echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=30)" +echo "All runs complete. Results in ${RESULTS_DIR}/" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" echo "============================================================" -COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens 4096 \ - --topk-vals 30 \ - --num-kv-heads 2 \ - --distributions normal \ - --counters \ - --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 \ - --repeat 5 \ - --output-json "${COUNTER_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" -echo ">>> Counters saved to ${COUNTER_JSON}" \ No newline at end of file diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 6848e1ea..9116b722 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -1,370 +1,217 @@ #!/usr/bin/env bash +# ============================================================ +# E2E accuracy sweep over the surviving parametric mapping modes. +# Each mode runs verify_algo.py with the per-mode hyperparameter +# that autotune_topk_mapping.py picked as having the lowest +# measured fused-topk-kernel latency. +# +# Mapping modes (after the lean refactor): +# 0: None — unmapped baseline (no remap) +# 3: Power — y = sign(x) * |x|^p +# 4: Log — y = sign(x) * log(|x| + 1) [no knob] +# 6: Asinh — y = asinh(beta * x) +# 7: Log1p — y = sign(x) * log1p(alpha * |x|) +# 9: Erf — y = erf(alpha * x) +# 10: Tanh — y = tanh(alpha * x) +# 13: ExpStretch — y = exp(alpha * x) +# ============================================================ set -e -# use CUDA_VISIBLE_DEVICES to set the GPU id you want to use -# Mapping functions: -# 0: None — original fp16 bit-pattern bucketing -# 1: LUT CDF — LUT-based CDF equalization (calibrated) -# 2: Quantile — piecewise-linear quantile mapping (calibrated) -# 3: Power — y = sign(x) * |x|^p -# 4: Log — y = sign(x) * log(|x| + 1) -# 5: Index Cache — reuse previous layer's indices -# 6: Asinh — y = asinh(beta * x) -# 7: Log1p — y = sign(x) * log1p(alpha * |x|) -# 8: Trunc8 — bf16 upper-8-bit bucketing -# 9: Erf — y = erf(alpha * x) -# 10: Tanh — y = tanh(alpha * x) -# 11: Subtract — x - pivot (RadiK-style scatter) SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── GPU_ID=5 TOPK_VAL=30 -BENCHMARKS="amc23" # space-separated list, e.g. "amc23 aime24" +BENCHMARKS="amc23" +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=2 +SEQ_LEN=32768 +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 -# ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in - --topk-val) TOPK_VAL="$2"; shift 2 ;; - --gpu) GPU_ID="$2"; shift 2 ;; - --benchmark) BENCHMARKS="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --benchmark) BENCHMARKS="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" -sparse_algos=( - "block_sparse_attention" -) - -# Path to real-data histograms from calibration (for auto-tuning) -REAL_HISTOGRAMS="/data/datasets/xinrui/My_Projects/vortex_torch/examples/calibration/raw_histograms.npy" +sparse_algos=( "block_sparse_attention" ) BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') -RESULTS_DIR="results/topk${TOPK_VAL}_${BENCH_LABEL}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RESULTS_DIR="results/topk_mapping_${MODEL_SLUG}_topk${TOPK_VAL}_${BENCH_LABEL}" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) # ============================================================ -# Step 0: Auto-tune — find best hyperparameters per mode -# Uses topk_profile_histogram kernel on synthetic data (fast, no model) +# Step 0: Calibrate (optional) — real-distribution histograms # ============================================================ -echo "============================================================" -echo "Step 0: Auto-tuning hyperparameters (synthetic data)" -echo "============================================================" -AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --topk-val ${TOPK_VAL} \ - --batch-size 4 \ - --seq-len 32768 \ - --num-kv-heads 2 \ - --real-histograms "${REAL_HISTOGRAMS}" \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" -echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" -echo "" +if [ -z "${REAL_HISTOGRAMS}" ]; then + echo "============================================================" + echo "Step 0: Calibrating ${MODEL_NAME} for real-distribution histograms" + echo "============================================================" + CAL_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + mkdir -p "${CAL_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --mem 0.7 \ + --vortex-module-name "${sparse_algos[0]}" \ + --page-size "${BLOCK_SIZE}" \ + --output-dir "${CAL_DIR}" \ + 2>&1 | tee "${RESULTS_DIR}/calibrate_${TIMESTAMP}.log" + REAL_HISTOGRAMS="${CAL_DIR}/raw_histograms.npy" +fi + +# Pick up lut.npy / quantiles.npy if calibration produced them. +CALIB_DIR="$(dirname "${REAL_HISTOGRAMS}")" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" +[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" # ============================================================ -# Extract best per-mode hyperparameters from autotune JSON +# Step 1: Auto-tune — rank by profiled fused-topk kernel latency # ============================================================ +AUTOTUNE_JSON="${RESULTS_DIR}/autotune_${TIMESTAMP}.json" +if [ "${SKIP_AUTOTUNE}" -eq 0 ]; then + echo "============================================================" + echo "Step 1: Auto-tuning hyperparameters by fused-topk kernel latency" + echo "============================================================" + AUTOTUNE_EXTRA=() + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --topk-val ${TOPK_VAL} \ + --batch-size ${BATCH_SIZE} \ + --seq-len ${SEQ_LEN} \ + --num-kv-heads ${NUM_KV_HEADS} \ + --page-size ${BLOCK_SIZE} \ + --real-histograms "${REAL_HISTOGRAMS}" \ + "${AUTOTUNE_EXTRA[@]}" \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RESULTS_DIR}/autotune_${TIMESTAMP}.log" + echo ">>> Auto-tune results saved to ${AUTOTUNE_JSON}" +fi + +# Extract best per-mode hparam (ranked by measured kernel latency, lowest wins) eval "$(python3 -c " import json, sys data = json.load(open(sys.argv[1])) best = {} for r in data: m = r.get('mode') - if m in (3, 6, 7, 9, 10, 13, 14): - if m not in best or r['gini'] < best[m]['gini']: - best[m] = r -for m in (3, 6, 7, 9, 10, 13, 14): - print(f'BEST_HPARAM_{m}={best[m][\"param\"]}' if m in best else f'BEST_HPARAM_{m}=0.5') + lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + v = best.get(m, {}).get('param', 0.5) + print(f'BEST_HPARAM_{m}={v}') " "${AUTOTUNE_JSON}")" -echo ">>> Autotuned best powers: mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7} mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode13=${BEST_HPARAM_13} mode14=${BEST_HPARAM_14}" +echo ">>> Autotuned hparams (lowest topk kernel latency):" +echo " mode3=${BEST_HPARAM_3} mode6=${BEST_HPARAM_6} mode7=${BEST_HPARAM_7}" +echo " mode9=${BEST_HPARAM_9} mode10=${BEST_HPARAM_10} mode11=${BEST_HPARAM_11} mode13=${BEST_HPARAM_13}" echo "" -# ============================================================ -# Baseline: Original sglang kernel (no remap) -# ============================================================ -echo "============================================================" -echo "Baseline: sglang_ori (no remap)" -echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_ori_${TIMESTAMP}.log" - echo ">>> sglang_ori algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang_ori \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + for algo in "${sparse_algos[@]}"; do + local out="${RESULTS_DIR}/topk_mapping_${algo}_${label}_${TIMESTAMP}.log" + echo ">>> ${label} algo=${algo}" + local extra_args=() + if [ "${mode}" -eq 0 ]; then + extra_args+=(--topk-type sglang) + else + extra_args+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + { time python verify_algo.py \ + --trials 8 \ + --topk-val "${TOPK_VAL}" \ + --vortex-module-name "${algo}" \ + --model-name "${MODEL_NAME}" \ + --benchmark ${BENCHMARKS} \ + --mem 0.7 \ + "${extra_args[@]}" ; } \ + 2>&1 | tee "${out}" + done +} -# ============================================================ -# Step 1: Mode 3 (power) — autotuned best p -# ============================================================ echo "============================================================" -echo "Step 1: Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" +echo "Baseline: sglang (no remap)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_3_p${BEST_HPARAM_3}_${TIMESTAMP}.log" - echo ">>> Mode 3 (power) p=${BEST_HPARAM_3} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-hparam ${BEST_HPARAM_3} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 0 0.5 "sglang_m0" -# ============================================================ -# Step 2: Mode 6 (asinh) — autotuned best beta -# ============================================================ echo "============================================================" -echo "Step 2: Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" +echo "Mode 3 (power) — p=${BEST_HPARAM_3} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_6_beta${BEST_HPARAM_6}_${TIMESTAMP}.log" - echo ">>> Mode 6 (asinh) beta=${BEST_HPARAM_6} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-hparam ${BEST_HPARAM_6} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 3 "${BEST_HPARAM_3}" "sglang_m3_p${BEST_HPARAM_3}" -# ============================================================ -# Step 3: Mode 7 (log1p) — autotuned best alpha -# ============================================================ echo "============================================================" -echo "Step 3: Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" +echo "Mode 4 (log)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_7_alpha${BEST_HPARAM_7}_${TIMESTAMP}.log" - echo ">>> Mode 7 (log1p) alpha=${BEST_HPARAM_7} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-hparam ${BEST_HPARAM_7} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 4 0.5 "sglang_m4" -# ============================================================ -# Step 4: Mode 8 (trunc8) — fixed parameter -# ============================================================ echo "============================================================" -echo "Step 4: Mode 8 (trunc8)" +echo "Mode 6 (asinh) — beta=${BEST_HPARAM_6} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_8_${TIMESTAMP}.log" - echo ">>> Mode 8 (trunc8) algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 8 \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 6 "${BEST_HPARAM_6}" "sglang_m6_beta${BEST_HPARAM_6}" -# ============================================================ -# Step 5: Mode 9 (erf) — autotuned best alpha -# ============================================================ echo "============================================================" -echo "Step 5: Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" +echo "Mode 7 (log1p) — alpha=${BEST_HPARAM_7} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_9_alpha${BEST_HPARAM_9}_${TIMESTAMP}.log" - echo ">>> Mode 9 (erf) alpha=${BEST_HPARAM_9} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 9 \ - --topk-mapping-hparam ${BEST_HPARAM_9} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 7 "${BEST_HPARAM_7}" "sglang_m7_alpha${BEST_HPARAM_7}" -# ============================================================ -# Step 6: Mode 10 (tanh) — autotuned best alpha -# ============================================================ echo "============================================================" -echo "Step 6: Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" +echo "Mode 9 (erf) — alpha=${BEST_HPARAM_9} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_10_alpha${BEST_HPARAM_10}_${TIMESTAMP}.log" - echo ">>> Mode 10 (tanh) alpha=${BEST_HPARAM_10} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 10 \ - --topk-mapping-hparam ${BEST_HPARAM_10} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 9 "${BEST_HPARAM_9}" "sglang_m9_alpha${BEST_HPARAM_9}" -# ============================================================ -# Step 7: Mode 11 (subtract) — fixed parameter -# ============================================================ echo "============================================================" -echo "Step 7: Mode 11 (subtract)" +echo "Mode 10 (tanh) — alpha=${BEST_HPARAM_10} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_11_${TIMESTAMP}.log" - echo ">>> Mode 11 (subtract) algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 11 \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 10 "${BEST_HPARAM_10}" "sglang_m10_alpha${BEST_HPARAM_10}" -# ============================================================ -# Step 8: Mode 12 (adaptive_tail_window), rho=4.0 -# ============================================================ -echo "" echo "============================================================" -echo "Step 8: Mode 12 (adaptive_tail_window), rho=4.0" +echo "Mode 8 (trunc8)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_12_${TIMESTAMP}.log" - echo ">>> Mode 12 (adaptive_tail_window) algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 12 \ - --topk-mapping-hparam 4.0 \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 8 0.5 "sglang_m8" -# ============================================================ -# Step 9: Mode 13 (exp_stretch) — autotuned best alpha -# ============================================================ -echo "" -echo "============================================================" -echo "Step 9: Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" -echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_13_alpha${BEST_HPARAM_13}_${TIMESTAMP}.log" - echo ">>> Mode 13 (exp_stretch) alpha=${BEST_HPARAM_13} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 13 \ - --topk-mapping-hparam ${BEST_HPARAM_13} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done - -# ============================================================ -# Step 10: Mode 14 (topk_window) — autotuned best rho -# ============================================================ -echo "" echo "============================================================" -echo "Step 10: Mode 14 (topk_window) — rho=${BEST_HPARAM_14} (autotuned)" +echo "Mode 11 (subtract) — pivot=${BEST_HPARAM_11} (autotuned)" echo "============================================================" -for algo in "${sparse_algos[@]}"; do - OUTFILE="${RESULTS_DIR}/topk_mapping_${algo}_sglang_14_rho${BEST_HPARAM_14}_${TIMESTAMP}.log" - echo ">>> Mode 14 (topk_window) rho=${BEST_HPARAM_14} algo=${algo}" - { time python verify_algo.py \ - --trials 8 \ - --topk-val ${TOPK_VAL} \ - --vortex-module-name "${algo}" \ - --model-name Qwen/Qwen3-1.7B \ - --topk-type sglang \ - --topk-mapping-mode 14 \ - --topk-mapping-hparam ${BEST_HPARAM_14} \ - --benchmark ${BENCHMARKS} \ - --mem 0.7 ; } \ - 2>&1 | tee "${OUTFILE}" -done +run_verify 11 "${BEST_HPARAM_11}" "sglang_m11_pivot${BEST_HPARAM_11}" -# ============================================================ -# Counter profiling: collect COUNTER_NUM_EQUAL for all modes -# (single extra kernel call per mode, no overhead on accuracy runs) -# ============================================================ -echo "" echo "============================================================" -echo "Counter profiling: COUNTER_NUM_EQUAL per mode (topk=${TOPK_VAL})" +echo "Mode 13 (exp_stretch) — alpha=${BEST_HPARAM_13} (autotuned)" echo "============================================================" -COUNTER_JSON="${RESULTS_DIR}/counters_${TIMESTAMP}.json" -PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 \ - --seq-lens 4096 \ - --topk-vals ${TOPK_VAL} \ - --num-kv-heads 2 \ - --distributions normal \ - --counters \ - --real-histograms "${REAL_HISTOGRAMS}" \ - --autotune-json "${AUTOTUNE_JSON}" \ - --filter-kernels sglang_ori sglang_m0 sglang_m3 sglang_m6 sglang_m7 sglang_m8 sglang_m9 sglang_m10 sglang_m11 sglang_m13 sglang_m14 \ - --mapping-hparam-13 ${BEST_HPARAM_13} --mapping-hparam-14 ${BEST_HPARAM_14} \ - --repeat 5 \ - --output-json "${COUNTER_JSON}" \ - 2>&1 | tee "${RESULTS_DIR}/counters_${TIMESTAMP}.log" -echo ">>> Counters saved to ${COUNTER_JSON}" +run_verify 13 "${BEST_HPARAM_13}" "sglang_m13_alpha${BEST_HPARAM_13}" -# ============================================================ -# Summary -# ============================================================ echo "" echo "============================================================" echo "All runs complete. Results in ${RESULTS_DIR}/" -echo " Auto-tune: ${AUTOTUNE_JSON}" -echo " Counters: ${COUNTER_JSON}" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Auto-tune: ${AUTOTUNE_JSON}" echo " Mode 3 (power): p = ${BEST_HPARAM_3} (autotuned)" echo " Mode 6 (asinh): beta = ${BEST_HPARAM_6} (autotuned)" echo " Mode 7 (log1p): alpha = ${BEST_HPARAM_7} (autotuned)" -echo " Mode 8 (trunc8): (fixed)" echo " Mode 9 (erf): alpha = ${BEST_HPARAM_9} (autotuned)" echo " Mode 10 (tanh): alpha = ${BEST_HPARAM_10} (autotuned)" -echo " Mode 11 (subtract): (fixed)" -echo " Mode 12 (tail_win): rho = 4.0" echo " Mode 13 (exp_stretch):alpha = ${BEST_HPARAM_13} (autotuned)" -echo " Mode 14 (topk_window):rho = ${BEST_HPARAM_14} (autotuned)" echo "============================================================" diff --git a/examples/verify_sparse_backends.sh b/examples/verify_external_backends.sh similarity index 100% rename from examples/verify_sparse_backends.sh rename to examples/verify_external_backends.sh diff --git a/vortex_torch/indexer/context.py b/vortex_torch/indexer/context.py index 8142fbc6..17dea66c 100644 --- a/vortex_torch/indexer/context.py +++ b/vortex_torch/indexer/context.py @@ -1,6 +1,5 @@ from __future__ import annotations from typing import Any, Final, Union -import numpy as np import torch from ..abs import ContextBase from ..utils import UNSET, Mode @@ -24,8 +23,7 @@ class Context(ContextBase): "num_sms", "page_size", "max_num_pages", "max_num_pages_per_request", # misc "indexer_dtype", "topk_val", "page_reserved_bos", "page_reserved_eos", "topk_type", - "topk_mapping_mode", "topk_mapping_power", "topk_mapping_lut", "topk_mapping_quantiles", - "topk_mapping_noscale", + "topk_mapping_mode", "topk_mapping_power", "topk_histogram_enabled", # auxilary memory in graph @@ -72,13 +70,10 @@ class Context(ContextBase): topk_val: int #: Top-K value used in pruning or selection. page_reserved_bos: int #: Reserved page count for BOS (begin-of-sequence). page_reserved_eos: int #: Reserved page count for EOS (end-of-sequence). - topk_type: str #: TopK kernel type: "naive" or "sglang". - topk_mapping_mode: int #: TopK mapping mode (0=none, 1=lut, 2=quantile, 3=power, 4=log). - topk_mapping_power: float #: Power exponent for mapping mode 3. - topk_mapping_lut: object #: Optional uint8[256] LUT tensor for mapping mode 1. - topk_mapping_quantiles: object #: Optional float32[256] quantiles tensor for mapping mode 2. - topk_mapping_noscale: bool #: Skip auto-range linear scaling, use fp16 bucketing on f(x) (default False). - topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). + topk_type: str #: TopK kernel type: "naive", "sglang" (unmapped) or "sglang_fused" (remap+topk). + topk_mapping_mode: int #: TopK mapping mode for sglang_fused (0=none, 3=power, 4=log, 6=asinh, 7=log1p, 9=erf, 10=tanh, 13=exp_stretch). + topk_mapping_power: float #: Hyperparameter (p / alpha / beta) for the active mapping mode. + topk_histogram_enabled: bool #: Enable histogram profiling during inference (default False). # --- auxiliary --- _aux_total_bytes: int #: Accumulated auxiliary memory in bytes. @@ -158,22 +153,10 @@ def create(self, parent: Any, model_runner: Any, *, overwrite: bool = False) -> self.topk_type = getattr(sa, "vortex_topk_type", "naive") self.topk_mapping_mode = getattr(sa, "vortex_topk_mapping_mode", 0) self.topk_mapping_power = getattr(sa, "vortex_topk_mapping_power", 0.5) - self.topk_mapping_noscale = getattr(sa, "vortex_topk_mapping_noscale", False) self.topk_histogram_enabled = getattr(sa, "vortex_topk_histogram", False) device = getattr(model_runner, "device", "cpu") - # Load calibration data from .npy files when paths are provided - lut_path = getattr(sa, 'vortex_topk_mapping_lut_path', None) - if lut_path is not None: - lut_np = np.load(lut_path).astype(np.uint8) - self.topk_mapping_lut = torch.from_numpy(lut_np).to(device) - - quantiles_path = getattr(sa, 'vortex_topk_mapping_quantiles_path', None) - if quantiles_path is not None: - q_np = np.load(quantiles_path).astype(np.float32) - self.topk_mapping_quantiles = torch.from_numpy(q_np).to(device) - self.max_num_workloads = ( (self.max_num_pages // max(1, sa.vortex_lb_min_chunk_size)) + max_bs * self.num_kv_heads ) diff --git a/vortex_torch/indexer/output_func.py b/vortex_torch/indexer/output_func.py index e4424cdf..889e0682 100644 --- a/vortex_torch/indexer/output_func.py +++ b/vortex_torch/indexer/output_func.py @@ -1,10 +1,9 @@ import torch from typing import Dict, Callable, List, Optional from ..abs import vOp -from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_ori, topk_profile_histogram +from vortex_torch_C import topk_output, topk_output_sglang, topk_output_sglang_fused, topk_profile_histogram from .context import Context from ..abs import vTensor, FORMAT -from ..utils import UNSET # --- Module-level histogram accumulator for offline calibration --- _calibration_histograms: List[torch.Tensor] = [] @@ -91,7 +90,7 @@ class topK(vOp): FORMAT.RAGGED: { "naive": topk_output, "sglang": topk_output_sglang, - "sglang_ori": topk_output_sglang_ori, + "sglang_fused": topk_output_sglang_fused, }, } @@ -245,17 +244,7 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso assert self.impl is not None, f"{prefix}execute called before profile() (impl is None)" if self.topk_type == "sglang": - # topk_output_sglang: (x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, sparse_kv_indices, ...) - mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) - mapping_hparam = getattr(ctx, 'topk_mapping_hparam', getattr(ctx, 'topk_mapping_power', 0.5)) - mapping_lut = getattr(ctx, 'topk_mapping_lut', None) - mapping_quantiles = getattr(ctx, 'topk_mapping_quantiles', None) - mapping_noscale = getattr(ctx, 'topk_mapping_noscale', False) - # UNSET sentinel is not a valid torch.Tensor — coerce to None - if mapping_lut is UNSET: - mapping_lut = None - if mapping_quantiles is UNSET: - mapping_quantiles = None + # topk_output_sglang: unmapped baseline (no remap). self.impl( x, ctx.dense_kv_indptr, @@ -267,14 +256,14 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, ctx.max_num_pages_per_request, - mapping_mode, - mapping_hparam, - mapping_lut, - mapping_quantiles, - mapping_noscale, ) - elif self.topk_type == "sglang_ori": - # topk_output_sglang_ori: same CSR interface, no mapping params + elif self.topk_type == "sglang_fused": + # topk_output_sglang_fused: single-launch fused remap + topk. + mapping_mode = getattr(ctx, 'topk_mapping_mode', 0) + mapping_power = getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + ) self.impl( x, ctx.dense_kv_indptr, @@ -286,6 +275,8 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso ctx.page_reserved_bos, ctx.page_reserved_eos, ctx.max_num_pages_per_request, + int(mapping_mode), + float(mapping_power), ) else: # topk_output (naive): (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, sparse_kv_indices, ...) @@ -307,11 +298,19 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso # are not permitted while a stream is being captured. if ( getattr(ctx, 'topk_histogram_enabled', False) - and self.topk_type == "sglang" + and self.topk_type in ("sglang", "sglang_fused") and not torch.cuda.is_current_stream_capturing() ): eff_bs = ctx.batch_size * ctx.num_kv_heads self.last_histograms = torch.zeros(eff_bs, 256, dtype=torch.int32, device=x.device) + hist_mode = 0 + hist_power = 0.5 + if self.topk_type == "sglang_fused": + hist_mode = int(getattr(ctx, 'topk_mapping_mode', 0)) + hist_power = float(getattr( + ctx, 'topk_mapping_hparam', + getattr(ctx, 'topk_mapping_power', 0.5), + )) topk_profile_histogram( x, ctx.dense_kv_indptr, @@ -319,12 +318,8 @@ def execute(self, x: torch.Tensor, o: torch.Tensor, ctx: Context) -> torch.Tenso eff_bs, ctx.page_reserved_bos, ctx.page_reserved_eos, - mapping_mode, - mapping_hparam, - mapping_lut, - mapping_quantiles, - mapping_noscale, - ctx.topk_val, + hist_mode, + hist_power, ) # Accumulate histograms for offline calibration _calibration_histograms.append(self.last_histograms.cpu().clone()) From aecde11194f8ce9a55a4790d27ffc7e46fe0bb8b Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 13 Apr 2026 03:15:11 -0400 Subject: [PATCH 19/24] Refactor TopK mapping and benchmarking scripts for enhanced profiling and usability - Updated autotune_topk_mapping.py to optimize hyperparameter tuning based on kernel latency. - Simplified the sweep grid and improved documentation for usage. - Enhanced bench_topk.py to expose public helpers and added CLI modes for benchmarking. - Introduced new remap functions and improved kernel integration for profiling. - Added watchdog timeout option in calibrate_topk.py for SGLang scheduler. - Removed outdated greedy_layer_search.py as part of code cleanup. --- examples/run_topk_benchmark.sh | 332 +++++++++++++-------------------- 1 file changed, 128 insertions(+), 204 deletions(-) diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index d57e2f1c..33c6e40d 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -1,54 +1,47 @@ #!/usr/bin/env bash # ============================================================ -# TopK Benchmark +# Unified TopK Benchmark # -# Compares ALL TopK kernel variants under controlled conditions: -# Step 1: Calibrate (for modes 1/2) -# Step 2: Kernel-level latency (bench_topk.py, all 6 modes) -# Step 3: E2E accuracy (verify_algo.py) -# - Full-attention baseline first -# - Then naive, sglang mode 0/1/2/3/4 -# - Same model, same prompts, deterministic sampling -# -# Fairness improvements over verify_algo_topk_mapping.sh: -# - Full-attention baseline for absolute reference -# - All modes in one sweep (including calibrated 1/2) -# - Sequential runs on same CUDA device minimize interference -# - Deterministic sampling (temperature=0) for reproducibility -# - Results saved to a single timestamped directory +# Three-step pipeline on a single configurable model: +# Step 1: Calibrate — run the model to collect +# real-distribution histograms +# (raw_histograms.npy, lut.npy, +# quantiles.npy). +# Step 2: Latency autotune + bench — rank per-mode hparams by +# measured fused-topk kernel +# latency, then run the +# remap / topk / fused / baseline +# comparison. +# Step 3: E2E accuracy — verify_algo.py on the same +# model for the unmapped baseline +# plus each mapping mode, with +# autotuned hparams. # # Usage: -# bash run_topk_benchmark.sh [OPTIONS] -# -# Options: -# --model-name NAME HuggingFace model (default: Qwen/Qwen3-1.7B) -# --topk-val K Top-k value (default: 30) -# --trials N E2E trial count (default: 8) -# --mem FRAC GPU memory fraction (default: 0.7) -# --gpu GPU_ID CUDA device (default: 0) -# --algo NAME Sparse attention algorithm (default: block_sparse_attention) -# --skip-calibrate Reuse existing calibration data -# --skip-kernel Skip kernel-level benchmark (step 2) -# --skip-e2e Skip E2E accuracy benchmark (step 3) +# bash run_topk_benchmark.sh --gpu 0 +# bash run_topk_benchmark.sh --gpu 0 --model-name Qwen/Qwen3-8B \ +# --block-size 32 --topk-val 512 # ============================================================ set -euo pipefail -# use GPU_ID to set the GPU id you want to use -GPU_ID=4 - SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 TRIALS=8 MEM=0.7 ALGO="block_sparse_attention" +BLOCK_SIZE=16 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +SEQ_LEN=32768 +BENCHMARKS="amc23" SKIP_CALIBRATE=false SKIP_KERNEL=false SKIP_E2E=true -BENCHMARKS="amc23" # space-separated list, e.g. "amc23 aime24" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do @@ -60,9 +53,13 @@ while [[ $# -gt 0 ]]; do --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --benchmark) BENCHMARKS="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; --skip-calibrate) SKIP_CALIBRATE=true; shift ;; --skip-kernel) SKIP_KERNEL=true; shift ;; - --skip-e2e) SKIP_E2E=true; shift ;; + --skip-e2e) SKIP_E2E=false; shift ;; # --skip-e2e actually toggles it OFF (enables) *) echo "Unknown option: $1"; exit 1 ;; esac done @@ -73,123 +70,128 @@ RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) BENCH_LABEL=$(echo "${BENCHMARKS}" | tr ' ' '_') -RUN_DIR="${RESULTS_DIR}/topk_benchmark_${BENCH_LABEL}_${TIMESTAMP}" +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/topk_benchmark_${MODEL_SLUG}_${BENCH_LABEL}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" echo "============================================================" -echo "Fair Unified TopK Benchmark" -echo " Model: ${MODEL_NAME}" -echo " Algorithm: ${ALGO}" -echo " TopK: ${TOPK_VAL}" -echo " Trials: ${TRIALS}" -echo " GPU: ${GPU_ID}" -echo " Output: ${RUN_DIR}" +echo "Unified TopK Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN}" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Trials: ${TRIALS}" +echo " GPU: ${GPU_ID}" +echo " Output: ${RUN_DIR}" echo "============================================================" -# ── Step 1: Calibrate (for modes 1/2) ──────────────────────── +# ── Step 1: Calibrate ──────────────────────────────────────── CALIBRATION_DIR="${RUN_DIR}/calibration" if [ "${SKIP_CALIBRATE}" = true ] && [ -d "${CALIBRATION_DIR}" ]; then echo "" echo ">>> Step 1: SKIPPED (--skip-calibrate)" else echo "" - echo ">>> Step 1: Calibrating — collecting histograms for LUT/quantile modes" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — real topk histograms + LUT/quantiles" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ --vortex-module-name "${ALGO}" \ + --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" echo ">>> Step 1: Done." fi -# ── Step 2: Kernel-level latency benchmark ──────────────────── +REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" +LUT_PATH="" +Q_PATH="" +[ -f "${CALIBRATION_DIR}/lut.npy" ] && LUT_PATH="${CALIBRATION_DIR}/lut.npy" +[ -f "${CALIBRATION_DIR}/quantiles.npy" ] && Q_PATH="${CALIBRATION_DIR}/quantiles.npy" +[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" +[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" + +# ── Step 2: Latency autotune + remap bench ─────────────────── +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" if [ "${SKIP_KERNEL}" = true ]; then echo "" echo ">>> Step 2: SKIPPED (--skip-kernel)" else - # Step 2a: Auto-tune parametric mapping modes (must run before bench) echo "" - echo ">>> Step 2a: Auto-tuning parametric mapping hyperparameters" - AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" - REAL_HIST_ARGS="" - if [ -f "${CALIBRATION_DIR}/raw_histograms.npy" ]; then - REAL_HIST_ARGS="--real-histograms ${CALIBRATION_DIR}/raw_histograms.npy" - fi - python "${BENCH_DIR}/autotune_topk_mapping.py" \ + echo ">>> Step 2a: Auto-tuning per-mode hparams by fused-topk kernel latency" + AUTOTUNE_EXTRA=() + [ -f "${REAL_HIST_PATH}" ] && AUTOTUNE_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") + [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --topk-val "${TOPK_VAL}" \ - --batch-size 4 \ - --seq-len 32768 \ - --num-kv-heads 2 \ - ${REAL_HIST_ARGS} \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --page-size "${BLOCK_SIZE}" \ + --warmup 20 --repeat 100 \ + --collect-stats \ + "${AUTOTUNE_EXTRA[@]}" \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2a_autotune.log" - echo ">>> Step 2a: Done. Autotune results saved to ${AUTOTUNE_JSON}" + echo ">>> Step 2a: Done. Autotune saved to ${AUTOTUNE_JSON}" - # Step 2b: Kernel-level latency + histogram benchmark (using autotune params) echo "" - echo ">>> Step 2b: Kernel-level latency benchmark (all modes)" - + echo ">>> Step 2b: Remap benchmark (baseline / fused / remap / split) with autotuned hparams" BENCH_JSON="${RUN_DIR}/kernel_latency.json" - - # Build calibration args - LUT_ARGS="" - if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then - LUT_ARGS="--lut-path ${CALIBRATION_DIR}/lut.npy" - fi - QUANTILES_ARGS="" - if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then - QUANTILES_ARGS="--quantiles-path ${CALIBRATION_DIR}/quantiles.npy" - fi - - python "${BENCH_DIR}/bench_topk.py" \ - --batch-sizes 4 8 16 32 \ - --seq-lens 2048 4096 8192 16384 32768 \ + BENCH_EXTRA=() + [ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") + [ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ --topk-vals "${TOPK_VAL}" \ - --num-kv-heads 2 4 \ - --distributions normal lognormal uniform \ - --histogram \ - --hit-rate \ - --warmup 20 \ - --repeat 100 \ - ${LUT_ARGS} \ - ${QUANTILES_ARGS} \ + --page-size "${BLOCK_SIZE}" \ + --distributions normal bucket_uniform \ + --mapping-modes 0 1 2 3 6 7 8 9 10 11 13 \ --autotune-json "${AUTOTUNE_JSON}" \ + "${BENCH_EXTRA[@]}" \ + --warmup 20 --repeat 100 \ --output-json "${BENCH_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2b_kernel_bench.log" - echo ">>> Step 2b: Done. Results saved to ${BENCH_JSON}" - - # Step 2c: Per-mode distribution analysis - echo "" - echo ">>> Step 2c: Generating per-mode distribution analysis" - - python "${BENCH_DIR}/analyze_topk_distribution.py" \ - --bench-json "${BENCH_JSON}" \ - ${REAL_HIST_ARGS} \ - --output-dir "${RUN_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step2c_analyze.log" - - echo ">>> Step 2c: Done. Per-mode plots saved to ${RUN_DIR}" fi -# ── Step 3: E2E accuracy comparison ────────────────────────── +# ── Step 3: E2E accuracy ───────────────────────────────────── if [ "${SKIP_E2E}" = true ]; then echo "" - echo ">>> Step 3: SKIPPED (--skip-e2e)" + echo ">>> Step 3: SKIPPED (default). Pass --skip-e2e to toggle it ON." else echo "" echo ">>> Step 3: E2E accuracy comparison" + # Extract autotuned hparams per mode. + eval "$(python3 -c " +import json, sys +data = json.load(open(sys.argv[1])) +best = {} +for r in data: + m = r.get('mode'); lat = r.get('latency_ms') + if m is None or lat is None: continue + if m not in best or lat < best[m]['latency_ms']: + best[m] = r +for m in (3, 6, 7, 9, 10, 11, 13): + print(f'BEST_HPARAM_{m}={best.get(m, {}).get(\"param\", 0.5)}') +" "${AUTOTUNE_JSON}")" + E2E_DIR="${RUN_DIR}/e2e" mkdir -p "${E2E_DIR}" - # Helper: run verify_algo.py with common args and save output run_e2e() { - local label="$1" - shift + # $1=label, remaining args passed to verify_algo.py + local label="$1"; shift local logfile="${E2E_DIR}/${label}.log" echo "" echo " --- ${label} ---" @@ -203,122 +205,44 @@ else 2>&1 | tee "${logfile}" } - # 3a. Full-attention baseline (oracle) - run_e2e "full_attention_baseline" \ - --full-attention - - # 3b. Naive TopK - run_e2e "naive_mode0" \ - --vortex-module-name "${ALGO}" \ - --topk-type naive - - # 3c. SGLang mode 0 (no mapping) - run_e2e "sglang_mode0_none" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 0 - - # 3d. SGLang mode 1 (LUT CDF) — requires calibration - if [ -f "${CALIBRATION_DIR}/lut.npy" ]; then - run_e2e "sglang_mode1_lut_cdf" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 1 \ - --topk-mapping-lut-path "${CALIBRATION_DIR}/lut.npy" - else - echo " --- sglang_mode1_lut_cdf: SKIPPED (no lut.npy) ---" - fi - - # 3e. SGLang mode 2 (quantile) — requires calibration - if [ -f "${CALIBRATION_DIR}/quantiles.npy" ]; then - run_e2e "sglang_mode2_quantile" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 2 \ - --topk-mapping-quantiles-path "${CALIBRATION_DIR}/quantiles.npy" - else - echo " --- sglang_mode2_quantile: SKIPPED (no quantiles.npy) ---" - fi - - # 3f. SGLang mode 3 (power) - run_e2e "sglang_mode3_power" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 3 \ - --topk-mapping-power 0.5 - - # 3g. SGLang mode 4 (log) - run_e2e "sglang_mode4_log" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 4 - - # 3h. SGLang mode 6 (asinh) - run_e2e "sglang_mode6_asinh" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 6 \ - --topk-mapping-power 1.0 - - # 3i. SGLang mode 7 (log1p) - run_e2e "sglang_mode7_log1p" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 7 \ - --topk-mapping-power 1.0 - - # 3j. SGLang mode 8 (Trunc8) - run_e2e "sglang_mode8_trunc8" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 8 - - # 3k. SGLang mode 9 (Erf) - run_e2e "sglang_mode9_erf" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 9 \ - --topk-mapping-power 1.0 - - # 3l. SGLang mode 10 (Tanh) - run_e2e "sglang_mode10_tanh" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 10 \ - --topk-mapping-power 1.0 + run_mapped() { + # $1=mode $2=hparam $3=label + local mode="$1"; local hp="$2"; local label="$3" + local extra=(--vortex-module-name "${ALGO}") + if [ "${mode}" -eq 0 ]; then + extra+=(--topk-type sglang) + else + extra+=(--topk-type sglang_fused --topk-mapping-mode "${mode}" --topk-mapping-hparam "${hp}") + fi + run_e2e "${label}" "${extra[@]}" + } - # 3m. SGLang mode 11 (Subtract) - run_e2e "sglang_mode11_subtract" \ - --vortex-module-name "${ALGO}" \ - --topk-type sglang \ - --topk-mapping-mode 11 + run_e2e "full_attention_baseline" --full-attention + run_e2e "naive_topk" --vortex-module-name "${ALGO}" --topk-type naive + run_mapped 0 0.5 "sglang_m0_none" + run_mapped 3 "${BEST_HPARAM_3}" "sglang_m3_power_p${BEST_HPARAM_3}" + run_mapped 4 0.5 "sglang_m4_log" + run_mapped 6 "${BEST_HPARAM_6}" "sglang_m6_asinh_beta${BEST_HPARAM_6}" + run_mapped 7 "${BEST_HPARAM_7}" "sglang_m7_log1p_alpha${BEST_HPARAM_7}" + run_mapped 8 0.5 "sglang_m8_trunc8" + run_mapped 9 "${BEST_HPARAM_9}" "sglang_m9_erf_alpha${BEST_HPARAM_9}" + run_mapped 10 "${BEST_HPARAM_10}" "sglang_m10_tanh_alpha${BEST_HPARAM_10}" + run_mapped 11 "${BEST_HPARAM_11}" "sglang_m11_subtract_pivot${BEST_HPARAM_11}" + run_mapped 13 "${BEST_HPARAM_13}" "sglang_m13_expstretch_alpha${BEST_HPARAM_13}" echo "" echo ">>> Step 3: Done. E2E logs saved to ${E2E_DIR}/" - - # ── Summary table: extract pass@N from each log ───────────── - echo "" - echo "============================================================" - echo "E2E Accuracy Summary" - echo "============================================================" - printf "%-35s %s\n" "Configuration" "Result" - printf "%-35s %s\n" "-----------------------------------" "------" - for logfile in "${E2E_DIR}"/*.log; do - label=$(basename "${logfile}" .log) - # Extract the last line matching pass@ pattern - result=$(grep -oP 'pass@\d+\s*[=:]\s*[\d.]+' "${logfile}" | tail -1 || echo "N/A") - printf "%-35s %s\n" "${label}" "${result}" - done - echo "============================================================" fi # ── Final Summary ───────────────────────────────────────────── echo "" echo "============================================================" echo "TopK Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" echo " All results: ${RUN_DIR}" echo " Calibration: ${CALIBRATION_DIR}" +[ "${SKIP_KERNEL}" != true ] && echo " Autotune: ${AUTOTUNE_JSON}" [ "${SKIP_KERNEL}" != true ] && echo " Kernel JSON: ${RUN_DIR}/kernel_latency.json" -[ "${SKIP_KERNEL}" != true ] && echo " Per-mode: ${RUN_DIR}/distribution_comparison_m*.png, bucket_counts_m*.csv" -[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" +[ "${SKIP_E2E}" != true ] && echo " E2E logs: ${RUN_DIR}/e2e/" echo "============================================================" From 990b3ebc125ab601cf41a8e2037815b7fa652186 Mon Sep 17 00:00:00 2001 From: UED Date: Tue, 14 Apr 2026 03:41:46 -0400 Subject: [PATCH 20/24] - Introduced topk_output_sglang_ori function for the original SGLang kernel in vortex_torch_C. - Updated setup.py to include the new source file for the original kernel. - Enhanced autotune_topk_mapping.py and bench_topk.py to support new mapping modes and original kernel integration. - Expanded the sweep grid in autotune_topk_mapping.py for improved hyperparameter tuning. - Added a new command-line argument in calibrate_topk.py for maximum total tokens to manage KV pool size. - Removed outdated remap_function_bench.sh script as part of code cleanup. --- benchmarks/autotune_topk_mapping.py | 225 ++++--- benchmarks/bench_topk.py | 491 ++++++++++++-- benchmarks/calibrate_topk.py | 12 + csrc/register.cc | 6 + csrc/register.h | 11 + csrc/topk.cu | 14 +- csrc/topk_mapping.cuh | 112 +++- csrc/topk_sglang.cu | 295 ++++++--- csrc/topk_sglang_ori.cu | 619 ++++++++++++++++++ csrc/topk_sglang_profile.cu | 64 +- examples/remap_function_bench_topk2028.sh | 252 +++++++ ...ench.sh => remap_function_bench_topk30.sh} | 57 +- examples/run_distribution_analysis.sh | 5 + examples/run_distribution_analysis_new.sh | 5 + examples/run_topk_benchmark.sh | 5 + examples/verify_algo_topk_mapping.sh | 4 + examples/verify_algo_topk_mapping_new.sh | 4 + setup.py | 1 + 18 files changed, 1901 insertions(+), 281 deletions(-) create mode 100644 csrc/topk_sglang_ori.cu create mode 100755 examples/remap_function_bench_topk2028.sh rename examples/{remap_function_bench.sh => remap_function_bench_topk30.sh} (84%) diff --git a/benchmarks/autotune_topk_mapping.py b/benchmarks/autotune_topk_mapping.py index db213213..e103953e 100644 --- a/benchmarks/autotune_topk_mapping.py +++ b/benchmarks/autotune_topk_mapping.py @@ -1,10 +1,21 @@ """ Auto-tune TopK mapping hyperparameters by profiled kernel latency. -For each (mode, hyperparameter) combo in the sweep grid, this script runs -the fused remap+topk kernel (topk_output_sglang_fused) on synthetic or -real-distribution inputs, measures end-to-end latency with CUDA events, -and picks the hyperparameter with the lowest measured latency per mode. +For each (mode, hyperparameter) combo in the sweep grid, this script picks +the hyperparameter whose remapped score distribution produces the lowest +*unfused* topk kernel latency. The measurement is a split-phase: + + 1. topk_remap_only(x, mode, power) → float32 buffer [NOT timed] + 2. topk_output_sglang(remapped) [TIMED] + +Timing only step 2 isolates the Stage-2 radix cost, which is what bucket +uniformity actually affects. The remap cost is the same constant regardless +of power, so it would only pollute the ranking. + +Non-arithmetic baselines (MAPPING_LUT_CDF=1, MAPPING_QUANTILE=2, +MAPPING_TRUNC8=8) route their mapping through compute_stage1_bin, not +apply_transform, so split-phase is a no-op for them. Those are timed via +the fused kernel and marked `timing_mode="fused_fallback"` in the output. Distribution statistics (gini, max/mean, counter-based Stage-2 cost) are still collected for diagnostics, but they do NOT drive the ranking — the @@ -25,31 +36,65 @@ import numpy as np import torch -from bench_topk import make_topk_inputs, bench_kernel, compute_histogram_stats +from bench_topk import ( + make_topk_inputs, + bench_kernel, + compute_histogram_stats, + scores_from_histogram, +) from vortex_torch_C import ( + topk_output_sglang, topk_output_sglang_fused, + topk_remap_only, topk_profile_histogram, topk_profile_counters, ) +# Modes where topk_mapping.cuh::apply_transform is a genuine value-space +# transform (power / asinh / log / log1p / erf / tanh / subtract / exp_stretch, +# plus the top-spreading shift_pow2 / shift_pow3 / linear_steep family) and +# also mode 0 (identity). For these the split-phase `remap_only + unfused +# topk` is correct. Modes 1/2/8 (LUT_CDF / QUANTILE / TRUNC8) apply their +# mapping inside compute_stage1_bin, so split-phase is a no-op. +ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + # Only parametric modes need auto-tuning. Mode 0 (none) and mode 4 (log) -# have no knob; mode 0 is always the baseline. +# have no knob; mode 0 is always the baseline. Sweep grids widened so the +# autotune actually explores the tails of each transform. SWEEP_GRID: Dict[int, List[float]] = { - 3: [0.1, 0.25, 0.5, 0.75, 0.9], # power: p - 6: [0.1, 0.5, 1.0, 2.0, 4.0], # asinh: beta - 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0], # log1p: alpha - 9: [0.1, 0.5, 1.0, 2.0, 4.0], # erf: alpha - 10: [0.1, 0.5, 1.0, 2.0, 4.0], # tanh: alpha - 11: [-1.0, -0.5, 0.0, 0.5, 1.0], # subtract: pivot (free hparam) - 13: [0.5, 1.0, 2.0, 4.0, 8.0], # exp_stretch: alpha + 3: [0.1, 0.5, 1.0, 2.0, 4.0, 5.0, 9.0], # power: p + 6: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # asinh: beta + 7: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # log1p: alpha + 9: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # erf: alpha + 10: [0.1, 0.5, 1.0, 2.0, 4.0, 8.0, 16.0], # tanh: alpha + 11: [-2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # subtract: pivot + 13: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # exp_stretch: alpha + 15: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # shift_pow2: pivot + 16: [-4.0, -2.0, -1.0, -0.5, 0.0, 0.5, 1.0, 2.0], # shift_pow3: pivot (widened) + 17: [0.5, 1.0, 2.0, 4.0, 8.0, 16.0, 32.0], # linear_steep: k + 18: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_square: pivot + 19: [-1.0, -0.5, -0.25, 0.0, 0.25, 0.5, 1.0], # half_cube: pivot + # dense_mant clamp: sweep a wide range because real attention scores + # can span [-400, +200] on some models (raw logits), not just [0, 1]. + 20: [0.0, 1.0, 5.0, 10.0, 20.0, 50.0, 100.0], # dense_mant: clamp pivot } -PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", 11: "pivot", 13: "alpha"} +PARAM_NAME = { + 3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", + 15: "pivot", 16: "pivot", 17: "k", + 18: "pivot", 19: "pivot", + 20: "clamp", +} MODE_NAMES = { 0: "none", 1: "lut_cdf", 2: "quantile", 3: "power", 4: "log", 6: "asinh", 7: "log1p", 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", 13: "exp_stretch", + 15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep", + 18: "half_square", 19: "half_cube", + 20: "dense_mant", } # Non-parametric modes — no knob to sweep; timed once as a reference point. @@ -59,54 +104,8 @@ # ---------- Real-distribution score generation ---------- - -def _key_to_fp16(key: int) -> np.float16: - """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" - bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) - return np.array([bits], dtype=np.uint16).view(np.float16)[0] - - -def _build_bin_range_table(): - """Return per-bin (lo, hi) fp16 value tables for all 256 radix bins.""" - all_bits = np.arange(65536, dtype=np.uint16) - all_fp16 = all_bits.view(np.float16) - keys = np.where( - (all_bits & 0x8000).astype(bool), - (~all_bits).astype(np.uint16), - all_bits | np.uint16(0x8000), - ) - bins = (keys >> 8).astype(np.uint8) - all_f32 = all_fp16.astype(np.float32) - valid = np.isfinite(all_f32) - bin_lo = np.full(256, np.inf, dtype=np.float32) - bin_hi = np.full(256, -np.inf, dtype=np.float32) - for b in range(256): - mask = (bins == b) & valid - if mask.any(): - vals = all_f32[mask] - bin_lo[b] = vals.min() - bin_hi[b] = vals.max() - empty = bin_lo > bin_hi - for b in np.where(empty)[0]: - val = float(_key_to_fp16((int(b) << 8) | 0x80)) - bin_lo[b] = val - bin_hi[b] = val - return bin_lo, bin_hi - - -def _scores_from_histogram(histogram: np.ndarray, total_pages: int, device="cuda") -> torch.Tensor: - bin_lo, bin_hi = _build_bin_range_table() - counts = histogram.astype(np.float64) - total = counts.sum() - if total == 0: - return torch.zeros(total_pages, 1, 1, dtype=torch.bfloat16, device=device) - probs = counts / total - bin_indices = np.random.choice(256, size=total_pages, p=probs) - lo = bin_lo[bin_indices] - hi = bin_hi[bin_indices] - rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) - scores_f32 = lo + rand * (hi - lo) - return torch.from_numpy(scores_f32).to(torch.bfloat16).reshape(total_pages, 1, 1).to(device) +# _build_bin_range_table / scores_from_histogram now live in bench_topk.py +# so both autotune and bench_topk draw scores from the same sampler. def _make_real_inputs(args, histogram: np.ndarray) -> dict: @@ -125,10 +124,13 @@ def _make_real_inputs(args, histogram: np.ndarray) -> dict: ) dense_kv_indices = torch.arange(total_dense, dtype=torch.int32, device="cuda") sparse_kv_indices = torch.zeros(eff_bs * sparse_per_seg, dtype=torch.int32, device="cuda") - x = _scores_from_histogram(histogram, total_dense) + x = scores_from_histogram(histogram, total_dense, device="cuda", + score_dtype=torch.bfloat16) + remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(x.shape) return { "x": x, + "remapped": remapped, "dense_kv_indptr": dense_kv_indptr, "sparse_kv_indptr": sparse_kv_indptr, "dense_kv_indices": dense_kv_indices, @@ -139,9 +141,20 @@ def _make_real_inputs(args, histogram: np.ndarray) -> dict: } +def _ensure_remapped_buffer(inputs: dict) -> torch.Tensor: + """Lazy-allocate a float32 buffer matching x.shape for the split-phase.""" + buf = inputs.get("remapped") + if buf is None: + x = inputs["x"] + buf = torch.empty(x.numel(), dtype=torch.float32, device=x.device).reshape(x.shape) + inputs["remapped"] = buf + return buf + + # ---------- Latency-based evaluation ---------- def _time_fused(inputs, args, mode: int, power: float) -> dict: + """Fused remap+topk kernel latency (used as fallback for modes 1/2/8).""" eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] inputs["sparse_kv_indices"].zero_() @@ -167,6 +180,59 @@ def _time_fused(inputs, args, mode: int, power: float) -> dict: warmup=args.warmup, repeat=args.repeat) +def _time_unfused_on_remapped(inputs, args, mode: int, power: float) -> dict: + """Time the unfused topk kernel on pre-remapped scores. + + For mode 0 the original scores are used directly. For every other + arithmetic mode we run topk_remap_only once (not timed) into a + pre-allocated float32 buffer, then time topk_output_sglang on that + buffer with bench_kernel's warmup + repeat loop. This isolates the + Stage-2 radix cost from the remap pass. + """ + eff_bs = inputs["eff_batch_size"] + pages_per_seg = inputs["num_pages_per_seg"] + + if mode == 0: + src = inputs["x"] + else: + remapped = _ensure_remapped_buffer(inputs) + topk_remap_only( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + float(power), + ) + torch.cuda.synchronize() + src = remapped + + inputs["sparse_kv_indices"].zero_() + call_args = ( + src, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, + args.topk_val, + args.reserved_bos, + args.reserved_eos, + pages_per_seg, + ) + return bench_kernel(topk_output_sglang, call_args, + warmup=args.warmup, repeat=args.repeat) + + +def _time_mode(inputs, args, mode: int, power: float) -> tuple: + """Returns (latency_dict, timing_mode_str).""" + if mode in ARITHMETIC_MODES: + return _time_unfused_on_remapped(inputs, args, mode, power), "unfused_on_remapped" + return _time_fused(inputs, args, mode, power), "fused_fallback" + + def _collect_diagnostics(inputs, args, mode: int, power: float) -> dict: """Optional distribution/counter stats for reporting only (post-timing).""" eff_bs = inputs["eff_batch_size"] @@ -209,6 +275,11 @@ def _collect_diagnostics(inputs, args, mode: int, power: float) -> dict: diag["threshold_bin_mean"] = c[:, 0].mean().item() diag["num_equal_mean"] = c[:, 2].mean().item() diag["refine_rounds_mean"] = c[:, 4].mean().item() + # selected_from_thr = topk_val - num_above (clamped >= 0). Used as + # a tie-breaker by bench_topk._load_autotune_hparams when several + # modes have indistinguishable latency. + sel_from_thr = (float(args.topk_val) - c[:, 1]).clamp(min=0.0) + diag["selected_from_thr_mean"] = sel_from_thr.mean().item() return diag @@ -218,13 +289,14 @@ def _run_sweep(args, inputs, dist_label: str) -> List[dict]: # Baselines: time them but their param is fixed. for mode, power in BASELINES: - lat = _time_fused(inputs, args, mode, power) + lat, tmode = _time_mode(inputs, args, mode, power) entry = { "mode": mode, "mode_name": MODE_NAMES.get(mode, f"m{mode}"), "param_name": "(baseline)", "param": power, "distribution": dist_label, + "timing_mode": tmode, "latency_ms": lat["mean_ms"], "latency_median_ms": lat["median_ms"], "latency_min_ms": lat["min_ms"], @@ -232,21 +304,22 @@ def _run_sweep(args, inputs, dist_label: str) -> List[dict]: entry.update(_collect_diagnostics(inputs, args, mode, power)) results.append(entry) print( - f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) baseline " - f" latency={lat['mean_ms']:.4f} ms" + f" mode={mode:>2d} ({MODE_NAMES[mode]:>10s}) baseline " + f" [{tmode:>20s}] latency={lat['mean_ms']:.4f} ms" ) # Parametric sweep, one (mode, param) combo at a time. for mode, values in SWEEP_GRID.items(): pname = PARAM_NAME[mode] for val in values: - lat = _time_fused(inputs, args, mode, float(val)) + lat, tmode = _time_mode(inputs, args, mode, float(val)) entry = { "mode": mode, "mode_name": MODE_NAMES.get(mode, f"m{mode}"), "param_name": pname, "param": float(val), "distribution": dist_label, + "timing_mode": tmode, "latency_ms": lat["mean_ms"], "latency_median_ms": lat["median_ms"], "latency_min_ms": lat["min_ms"], @@ -254,8 +327,8 @@ def _run_sweep(args, inputs, dist_label: str) -> List[dict]: entry.update(_collect_diagnostics(inputs, args, mode, float(val))) results.append(entry) print( - f" mode={mode:>2d} ({MODE_NAMES[mode]:>5s}) {pname}={val:<6.3f} " - f" latency={lat['mean_ms']:.4f} ms" + f" mode={mode:>2d} ({MODE_NAMES[mode]:>10s}) {pname}={val:<6.3f} " + f" [{tmode:>20s}] latency={lat['mean_ms']:.4f} ms" ) return results @@ -320,19 +393,11 @@ def main(): help="Path to .npy float32[256] quantile table for MAPPING_QUANTILE (mode 2).") args = parser.parse_args() + # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer evaluated — they + # don't use topk_mapping::apply_transform (their mapping is done inside + # compute_stage1_bin) and are kept out of the comparison entirely. args._mapping_lut = None args._mapping_quantiles = None - # Include modes 1/2 as baselines when calibration tables are provided. - if args.lut_path: - lut_np = np.load(args.lut_path).astype(np.uint8) - args._mapping_lut = torch.from_numpy(lut_np).cuda() - BASELINES.append((1, 0.5)) - print(f"[autotune] loaded LUT from {args.lut_path}") - if args.quantiles_path: - q_np = np.load(args.quantiles_path).astype(np.float32) - args._mapping_quantiles = torch.from_numpy(q_np).cuda() - BASELINES.append((2, 0.5)) - print(f"[autotune] loaded quantiles from {args.quantiles_path}") real_histogram: Optional[np.ndarray] = None if args.real_histograms: diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 4bd5becb..f0860c94 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -23,14 +23,25 @@ import torch from vortex_torch_C import ( - topk_output, - topk_output_sglang, # unmapped baseline - topk_output_sglang_fused, # fused remap + topk - topk_remap_only, # standalone remap + topk_output, # full CUB BlockRadixSort topk (max 4096 pages/seg) + topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) + topk_output_sglang_fused, # fused remap + 2-stage radix topk + topk_output_sglang_ori, # original SGLang reference kernel + topk_remap_only, # standalone value-space remap topk_profile_histogram, topk_profile_counters, ) +# topk_output's template ladder tops out at 8192 pages per segment +# (see topk.cu::topk_output, branches up to <= 8192). Runs larger than +# that hit TORCH_CHECK(false). +TOPK_OUTPUT_MAX_PAGES = 8192 + +# The ori kernel has TopK baked in at compile time. If setup.py was built +# with a different value, calls will fail; this is the topk_val that +# matches the current build of topk_sglang_ori.cu. +TOPK_ORI_BAKED_IN = 30 + MAPPING_MODE_NAMES = { 0: "None", @@ -45,32 +56,136 @@ 10: "Tanh", 11: "Subtract", 13: "ExpStretch", + 15: "ShiftPow2", + 16: "ShiftPow3", + 17: "LinearSteep", + 18: "HalfSquare", + 19: "HalfCube", + 20: "DenseMant", } +# Modes whose value-space transform is a real apply_transform() pass. Modes +# 1 (LUT_CDF), 2 (QUANTILE) and 8 (TRUNC8) apply their mapping inside +# compute_stage1_bin, not apply_transform — so `topk_remap_only` cannot +# reproduce them (the fp32 buffer would just contain the raw values). For +# those modes the split-phase numbers are N/A; only the fused kernel is a +# meaningful reference. +ARITHMETIC_MODES = {0, 3, 4, 6, 7, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20} + + +_AUTOTUNE_TIE_TOLERANCE_MS = 0.0002 # ≈ CUDA event noise floor at this kernel size + def _load_autotune_hparams(path: str) -> Dict[int, float]: """Load per-mode best hyperparameters from an autotune_results.json. The JSON is produced by autotune_topk_mapping.py and contains a list of - {mode, param, latency_ms, ...} entries. For each mode we pick the entry - with the lowest measured latency and return {mode: best_param}. - - Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; the - caller should override to taste. + {mode, param, latency_ms, num_equal_mean, selected_from_thr_mean, ...} + entries. For each mode we group all sweep entries, find the lowest + latency, then break ties (within `_AUTOTUNE_TIE_TOLERANCE_MS`) by: + + 1. Smallest `num_equal_mean` (= thr_size). Stage-2 cost is O(thr_size), + so a smaller threshold bin is a better proxy for real fused + latency than the noisy `latency_ms` measurement. + 2. Smallest `selected_from_thr_mean`. How many pages the topk has to + pull from the threshold bin during refinement. + 3. Lowest `latency_ms` again (final fallback). + + Modes with no parametric sweep (0=None, 4=Log) return a dummy 0.5; + the caller should override to taste. """ with open(path) as f: data = json.load(f) - best: Dict[int, dict] = {} + grouped: Dict[int, list] = {} for r in data: m = r.get("mode") lat = r.get("latency_ms") if m is None or lat is None: continue - if m not in best or lat < best[m]["latency_ms"]: - best[m] = r + grouped.setdefault(m, []).append(r) + + best: Dict[int, dict] = {} + for m, entries in grouped.items(): + min_lat = min(e["latency_ms"] for e in entries) + contenders = [ + e for e in entries + if e["latency_ms"] - min_lat <= _AUTOTUNE_TIE_TOLERANCE_MS + ] + # Tie-breakers: lowest num_equal_mean, then lowest sel_thr, + # then lowest latency. Missing diagnostic fields → +inf so they + # lose tie-breaks (we still keep them as fallback candidates). + def _rank_key(e): + return ( + e.get("num_equal_mean", float("inf")), + e.get("selected_from_thr_mean", float("inf")), + e["latency_ms"], + ) + best[m] = min(contenders, key=_rank_key) + return {m: float(r["param"]) for m, r in best.items()} +def _key_to_fp16(key: int) -> np.float16: + """Invert convert_to_uint8's sign-flip for a single 16-bit key.""" + bits = (key & 0x7FFF) if key >= 0x8000 else ((~key) & 0xFFFF) + return np.array([bits], dtype=np.uint16).view(np.float16)[0] + + +def build_bin_range_table(): + """Per-bin (lo, hi) fp16 value tables for the 256 Stage-1 radix bins. + + Shared by the real-distribution samplers in bench_topk.py and + autotune_topk_mapping.py so both scripts generate identical inputs. + """ + all_bits = np.arange(65536, dtype=np.uint16) + all_fp16 = all_bits.view(np.float16) + keys = np.where( + (all_bits & 0x8000).astype(bool), + (~all_bits).astype(np.uint16), + all_bits | np.uint16(0x8000), + ) + bins = (keys >> 8).astype(np.uint8) + all_f32 = all_fp16.astype(np.float32) + valid = np.isfinite(all_f32) + bin_lo = np.full(256, np.inf, dtype=np.float32) + bin_hi = np.full(256, -np.inf, dtype=np.float32) + for b in range(256): + mask = (bins == b) & valid + if mask.any(): + vals = all_f32[mask] + bin_lo[b] = vals.min() + bin_hi[b] = vals.max() + empty = bin_lo > bin_hi + for b in np.where(empty)[0]: + val = float(_key_to_fp16((int(b) << 8) | 0x80)) + bin_lo[b] = val + bin_hi[b] = val + return bin_lo, bin_hi + + +def scores_from_histogram( + histogram: np.ndarray, + total_pages: int, + device: str = "cuda", + score_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Sample `total_pages` scores whose Stage-1 bucket distribution matches + the given 256-bin histogram (produced by calibration). Each bucket is + sampled uniformly over the fp16 range that maps into it.""" + bin_lo, bin_hi = build_bin_range_table() + counts = histogram.astype(np.float64) + total = counts.sum() + if total == 0: + return torch.zeros(total_pages, 1, 1, dtype=score_dtype, device=device) + probs = counts / total + bin_indices = np.random.choice(256, size=total_pages, p=probs) + lo = bin_lo[bin_indices] + hi = bin_hi[bin_indices] + rand = np.random.uniform(0, 1, size=total_pages).astype(np.float32) + scores_f32 = lo + rand * (hi - lo) + return torch.from_numpy(scores_f32).to(score_dtype).reshape(total_pages, 1, 1).to(device) + + def make_topk_inputs( batch_size: int, num_kv_heads: int, @@ -81,9 +196,15 @@ def make_topk_inputs( reserved_eos: int, score_dtype: torch.dtype, distribution: str = "normal", + real_histogram: np.ndarray = None, device: str = "cuda", ) -> dict: - """Synthesize CSR-formatted paged attention inputs for kernel timing.""" + """Synthesize CSR-formatted paged attention inputs for kernel timing. + + When `real_histogram` is provided, scores are drawn from that 256-bin + distribution (ignoring `distribution`) so the benchmark sees the same + Stage-1 bucket distribution as the calibrated model. + """ eff_batch_size = batch_size * num_kv_heads num_pages_per_seg = math.ceil(seq_len / page_size) total_dense_pages = eff_batch_size * num_pages_per_seg @@ -101,24 +222,25 @@ def make_topk_inputs( dense_kv_indices = torch.arange(total_dense_pages, dtype=torch.int32, device=device) sparse_kv_indices = torch.zeros(total_sparse_pages, dtype=torch.int32, device=device) - if distribution == "normal": - x = torch.randn(total_dense_pages, 1, 1, device=device) + if real_histogram is not None: + x = scores_from_histogram(real_histogram, total_dense_pages, device=device, + score_dtype=score_dtype) + elif distribution == "normal": + x = torch.randn(total_dense_pages, 1, 1, device=device).to(score_dtype) elif distribution == "lognormal": - x = torch.randn(total_dense_pages, 1, 1, device=device).exp() + x = torch.randn(total_dense_pages, 1, 1, device=device).exp().to(score_dtype) elif distribution == "uniform": - x = torch.rand(total_dense_pages, 1, 1, device=device) + x = torch.rand(total_dense_pages, 1, 1, device=device).to(score_dtype) elif distribution == "bucket_uniform": # Uniform across all 256 fp16 radix buckets. Random uint16 bit # patterns → interpret as fp16. NaN/Inf patterns collapse to ±0. raw_bits = torch.randint(0, 65536, (total_dense_pages,), dtype=torch.int32, device=device) abs_bits = raw_bits & 0x7FFF raw_bits[abs_bits >= 0x7C00] = raw_bits[abs_bits >= 0x7C00] & 0x8000 - x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1) + x = raw_bits.to(torch.int16).view(torch.float16).float().reshape(total_dense_pages, 1, 1).to(score_dtype) else: raise ValueError(f"Unknown distribution: {distribution}") - x = x.to(score_dtype) - return { "x": x, "dense_kv_indptr": dense_kv_indptr, @@ -185,10 +307,9 @@ def compute_histogram_stats(histograms: torch.Tensor) -> dict: def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, power: float) -> dict: - """Run topk_profile_counters once and aggregate threshold-bin stats. - - Profile kernel is invoked AFTER all latency measurements have finished, - so the counter writes never contaminate timing. + """Run topk_profile_counters + topk_profile_histogram once and aggregate + threshold-bin / bucket-distribution stats. Profile kernels run AFTER all + latency measurements, so their writes never contaminate timing. """ eff_bs = inputs["eff_batch_size"] counter_buf = torch.zeros(eff_bs, 6, dtype=torch.int32, device="cuda") @@ -214,6 +335,40 @@ def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, p ) torch.cuda.synchronize() c = counter_buf.float() + + # Run the 256-bin histogram profile to compute the rank_target_bins + # metric: how many bins ABOVE the threshold bin (i.e. the bins whose + # pages are selected without Stage-2 refinement) actually contain + # selected pages, and the mean pages-per-such-bin. + hist_buf = torch.zeros(eff_bs, 256, dtype=torch.int32, device="cuda") + topk_profile_histogram( + inputs["x"], + inputs["dense_kv_indptr"], + hist_buf, + eff_bs, + args.reserved_bos, + args.reserved_eos, + mode, + power, + lut_t, + q_t, + ) + torch.cuda.synchronize() + + thr_idx = counter_buf[:, 0].to(torch.int64) # [eff_bs] + hist = hist_buf.to(torch.int64) # [eff_bs, 256] + bin_ids = torch.arange(256, device="cuda", dtype=torch.int64).unsqueeze(0) # [1, 256] + above_mask = bin_ids > thr_idx.unsqueeze(1) # [eff_bs, 256] + above_populated = ((hist > 0) & above_mask).sum(dim=1).float() # bins >thr with any pages + pages_above = (hist * above_mask.to(torch.int64)).sum(dim=1).float() # total pages in those bins + # Mean pages per populated above-threshold bin (per-segment, then + # averaged). Guard against divide-by-zero. + pages_per_bin = torch.where( + above_populated > 0, + pages_above / above_populated, + torch.zeros_like(above_populated), + ) + # Selected from threshold bin = topk_val - num_above (clamped >= 0). sel_from_thr = (float(topk_val) - c[:, 1]).clamp(min=0.0) return { @@ -225,6 +380,9 @@ def _collect_threshold_stats(inputs, topk_val, pages_per_seg, args, mode: int, p "selected_from_thr_mean": sel_from_thr.mean().item(), "selected_from_thr_max": sel_from_thr.max().item(), "refine_rounds_mean": c[:, 4].mean().item(), + # Rank-target metrics: how the top pages are actually spread. + "above_bins_mean": above_populated.mean().item(), + "pages_per_above_bin_mean": pages_per_bin.mean().item(), } @@ -241,6 +399,7 @@ def _resolve_hparam(args, mode: int) -> float: def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, distribution, modes: List[int]) -> dict: """Time baseline, fused, and split-phase for each mode at one config.""" + real_hist = getattr(args, "_real_histogram", None) if distribution == "real" else None inputs = make_topk_inputs( batch_size=batch_size, num_kv_heads=num_kv_heads, @@ -250,13 +409,17 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, score_dtype=torch.bfloat16, - distribution=distribution, + distribution=distribution if distribution != "real" else "normal", + real_histogram=real_hist, ) eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] total_dense = inputs["x"].numel() - # Baseline: unmapped topk. + # Baseline = unmapped topk_output_sglang (CUB two-stage radix, the + # kernel every mapped mode's split-phase ends up calling). This is + # the `base_us` column and also what the `None` row reports, so + # None's topk_us == base_us by construction. baseline_args = ( inputs["x"], inputs["dense_kv_indptr"], @@ -268,7 +431,55 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() baseline = bench_kernel(topk_output_sglang, baseline_args, args.warmup, args.repeat) + # Optional extra row: the full CUB BlockRadixSort topk from topk.cu. + # This is a "true naive" — exact sort, no bucketing tricks — for A/B + # against the 2-stage approximate baseline. Only runs when pages_per_seg + # fits the kernel's template ladder (<= TOPK_OUTPUT_MAX_PAGES = 4096). + naive_ms = None + if pages_per_seg <= TOPK_OUTPUT_MAX_PAGES: + naive_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["dense_kv_indices"], # NOTE: topk_output arg order differs + inputs["sparse_kv_indptr"], # from topk_output_sglang + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + naive_ms = bench_kernel( + topk_output, naive_args, args.warmup, args.repeat + )["mean_ms"] + + # Optional extra row: the original SGLang kernel from topk_sglang_ori.cu, + # compiled with TopK=TOPK_ORI_BAKED_IN. Only runs when topk_val matches + # that constant; otherwise the row is skipped with a warning. It is NOT + # used as the baseline — this is a separate A/B point so you can see the + # ori-vs-naive gap at a glance. + sglang_ori_ms = None + if topk_val == TOPK_ORI_BAKED_IN: + ori_indices = torch.empty(eff_bs, TOPK_ORI_BAKED_IN, + dtype=torch.int32, device="cuda") + ori_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + ori_indices, + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + sglang_ori_ms = bench_kernel( + topk_output_sglang_ori, ori_args, args.warmup, args.repeat + )["mean_ms"] + # Pre-allocate the float32 buffer used for the split-phase (remap → baseline). + # Split-phase remapped buffer is **float32** to preserve Stage-2 + # refinement precision. The fused kernel computes transforms in + # fp32 internally (so its Stage-2 sub-bin keys carry transform- + # dependent bits in positions [15:0]); a narrower remapped buffer + # (bf16 or fp16) would zero those bits on round-trip and change + # the Stage-2 tie-break ordering vs the fused path. fp32 is the + # only lossless choice. The kernel supports bf16 output too (see + # topk_remap_only's dispatch table) for experimental paths, but we + # don't use it here because correctness matters more than the + # small memory-bandwidth win. remapped = torch.empty(total_dense, dtype=torch.float32, device="cuda").reshape(inputs["x"].shape) config = { @@ -279,10 +490,87 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "distribution": distribution, "pages_per_seg": pages_per_seg, "baseline_ms": baseline["mean_ms"], + "naive_ms": naive_ms, + "sglang_ori_ms": sglang_ori_ms, "modes": [], } + # Naive row — full CUB BlockRadixSort from topk.cu. No mapping, no + # remap, no fused. Only populated when pages_per_seg fits the kernel. + if naive_ms is not None: + config["modes"].append({ + "mode": -2, # sentinel so ranking/autotune skip it + "mode_name": "Naive", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": naive_ms, + "split_total_ms": None, + "fused_ms": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + + # The None row is a pass-through to the naive baseline: no remap, no + # fused, and topk_us == base_us by construction. Distribution metrics + # are populated by running the profile kernels with mode=0 so the user + # can see the unmapped Stage-1 bucket layout as a reference. + none_stats = _collect_threshold_stats( + inputs, topk_val, pages_per_seg, args, mode=0, power=0.5 + ) + config["modes"].append({ + "mode": 0, + "mode_name": "None", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": baseline["mean_ms"], + "split_total_ms": None, + "fused_ms": None, + **none_stats, + }) + + # Extra row for the original SGLang kernel — only populated when the + # build's baked-in TopK matches topk_val. Also a pass-through (no + # remap, no fused); topk_us is the ori kernel latency. + if sglang_ori_ms is not None: + config["modes"].append({ + "mode": -1, # sentinel so ranking/autotune skip it + "mode_name": "sglang_ori", + "power": 0.5, + "remap_ms": None, + "topk_after_remap_ms": sglang_ori_ms, + "split_total_ms": None, + "fused_ms": None, + "threshold_bin_mean": 0.0, + "threshold_bin_max": 0.0, + "num_above_mean": 0.0, + "threshold_bin_size_mean": 0.0, + "threshold_bin_size_max": 0.0, + "selected_from_thr_mean": 0.0, + "selected_from_thr_max": 0.0, + "refine_rounds_mean": 0.0, + "above_bins_mean": 0.0, + "pages_per_above_bin_mean": 0.0, + }) + else: + print(f"[bench-remap] sglang_ori row SKIPPED: topk_val={topk_val} != " + f"TOPK_ORI_BAKED_IN ({TOPK_ORI_BAKED_IN}). Rebuild topk_sglang_ori.cu " + f"with a matching TopK to enable the row.") + for mode in modes: + # Mode 0 is already emitted as the `None` row above (pass-through + # to the ori baseline with no remap/fused). Skip to avoid a + # duplicate row and a spurious fused-mode-0 measurement. + if mode == 0: + continue + power = _resolve_hparam(args, mode) lut_t = getattr(args, "_mapping_lut", None) if mode == 1 else None @@ -299,30 +587,43 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) - # Split-phase timing: first the standalone remap, then the unmapped - # topk on the remapped buffer. - remap_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - remapped, - eff_bs, args.reserved_bos, args.reserved_eos, - mode, power, - ) - remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + # Split-phase timing is only meaningful for arithmetic modes. + # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside + # compute_stage1_bin, which topk_remap_only cannot reproduce, so we + # report N/A for the split-phase fields and rely on the fused kernel + # as the only valid reference latency. + if mode in ARITHMETIC_MODES: + remap_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + remapped, + eff_bs, args.reserved_bos, args.reserved_eos, + mode, power, + ) + remap_only = bench_kernel(topk_remap_only, remap_args, args.warmup, args.repeat) + + # Populate the remapped buffer once so the unfused-topk warmup + # iterations don't read stale data. + topk_remap_only(*remap_args) + torch.cuda.synchronize() + split_topk_args = ( + remapped, + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + ) + inputs["sparse_kv_indices"].zero_() + split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) - split_topk_args = ( - remapped, - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, - ) - # Run remap once so the buffer is populated for warmup of topk-on-remapped. - topk_remap_only(*remap_args) - torch.cuda.synchronize() - inputs["sparse_kv_indices"].zero_() - split_topk = bench_kernel(topk_output_sglang, split_topk_args, args.warmup, args.repeat) + remap_ms = remap_only["mean_ms"] + topk_after_remap_ms = split_topk["mean_ms"] + split_total_ms = remap_ms + topk_after_remap_ms + else: + remap_ms = None + topk_after_remap_ms = None + split_total_ms = None # Counter collection is run AFTER all timing measurements for this mode # so it cannot affect the timings. @@ -332,9 +633,9 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "mode": mode, "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), "power": power, - "remap_ms": remap_only["mean_ms"], - "topk_after_remap_ms": split_topk["mean_ms"], - "split_total_ms": remap_only["mean_ms"] + split_topk["mean_ms"], + "remap_ms": remap_ms, + "topk_after_remap_ms": topk_after_remap_ms, + "split_total_ms": split_total_ms, "fused_ms": fused["mean_ms"], **stats, } @@ -345,8 +646,9 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, def _print_remap_table(results: List[dict]) -> None: header = ( - f"{'mode':<12s} {'remap_us':>9s} {'topk_us':>9s} {'split_us':>9s} " - f"{'fused_us':>9s} {'base_us':>9s} {'thr_bin':>7s} {'thr_size':>8s} {'sel_thr':>7s}" + f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " + f"{'fused_ms':>9s} {'base_ms':>9s} {'thr_bin':>7s} {'thr_size':>8s} " + f"{'sel_thr':>7s} {'abv_bins':>8s} {'pg/bin':>7s}" ) for cfg in results: banner = ( @@ -355,36 +657,65 @@ def _print_remap_table(results: List[dict]) -> None: f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']}]" ) print(banner) - print(" Baseline: mapping_mode=0 (raw fp16 bucketing)") + extra_notes = [] + if cfg.get("naive_ms") is not None: + extra_notes.append("Naive row = topk.cu (CUB full sort)") + if cfg.get("sglang_ori_ms") is not None: + extra_notes.append("sglang_ori row = topk_sglang_ori.cu") + notes_str = "" + if extra_notes: + notes_str = " | " + " | ".join(extra_notes) + print(f" Baseline: topk_sglang.cu (CUB two-stage){notes_str}") print(header) print("-" * len(header)) - base_us = cfg["baseline_ms"] * 1000.0 + base_ms = cfg["baseline_ms"] for row in cfg["modes"]: - label = f"{row['mode_name']}(p={row['power']})" if row["mode"] != 0 else "None" + if row["mode"] == 0: + label = "None" + elif row["mode"] == -1: + label = row.get("mode_name", "sglang_ori") + elif row["mode"] == -2: + label = row.get("mode_name", "Naive") + else: + label = f"{row['mode_name']}(p={row['power']})" + def _fmt(v): + return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" + fused_str = _fmt(row.get("fused_ms")) print( - f"{label:<12s} " - f"{row['remap_ms'] * 1000.0:9.2f} " - f"{row['topk_after_remap_ms'] * 1000.0:9.2f} " - f"{row['split_total_ms'] * 1000.0:9.2f} " - f"{row['fused_ms'] * 1000.0:9.2f} " - f"{base_us:9.2f} " + f"{label:<14s} " + f"{_fmt(row['remap_ms'])} " + f"{_fmt(row['topk_after_remap_ms'])} " + f"{_fmt(row['split_total_ms'])} " + f"{fused_str} " + f"{base_ms:9.4f} " f"{row['threshold_bin_mean']:7.1f} " f"{row['threshold_bin_size_mean']:8.1f} " - f"{row['selected_from_thr_mean']:7.1f}" + f"{row['selected_from_thr_mean']:7.1f} " + f"{row.get('above_bins_mean', 0.0):8.1f} " + f"{row.get('pages_per_above_bin_mean', 0.0):7.1f}" ) def _run_remap_bench(args) -> None: modes = [int(m) for m in args.mapping_modes] - if 0 not in modes: - modes = [0] + modes + # Mode 0 is emitted as the "None" row from _remap_bench_one_config + # itself (pass-through to the ori baseline). Drop any user-supplied 0 + # to avoid a duplicate row. + modes = [m for m in modes if m != 0] + + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None: + if "real" not in distributions: + distributions.append("real") + print(f"[remap-bench] 'real' distribution enabled " + f"(histogram total count = {int(args._real_histogram.sum())})") results = [] for bs in args.batch_sizes: for heads in args.num_kv_heads: for seq_len in args.seq_lens: for topk_val in args.topk_vals: - for dist in args.distributions: + for dist in distributions: cfg = _remap_bench_one_config( args, bs, heads, seq_len, topk_val, dist, modes, ) @@ -401,17 +732,23 @@ def _run_remap_bench(args) -> None: def _run_latency_sweep(args) -> None: """Simple baseline-vs-fused latency sweep (no split-phase, no counters).""" modes = [int(m) for m in args.mapping_modes] + distributions = list(args.distributions) + if getattr(args, "_real_histogram", None) is not None and "real" not in distributions: + distributions.append("real") results = [] for bs in args.batch_sizes: for heads in args.num_kv_heads: for seq_len in args.seq_lens: for topk_val in args.topk_vals: - for dist in args.distributions: + for dist in distributions: + real_hist = args._real_histogram if dist == "real" else None inputs = make_topk_inputs( batch_size=bs, num_kv_heads=heads, seq_len=seq_len, page_size=args.page_size, topk_val=topk_val, reserved_bos=args.reserved_bos, reserved_eos=args.reserved_eos, - score_dtype=torch.bfloat16, distribution=dist, + score_dtype=torch.bfloat16, + distribution=dist if dist != "real" else "normal", + real_histogram=real_hist, ) eff_bs = inputs["eff_batch_size"] pages_per_seg = inputs["num_pages_per_seg"] @@ -469,7 +806,14 @@ def main(): p.add_argument("--topk-vals", type=int, nargs="+", default=[30]) p.add_argument("--distributions", type=str, nargs="+", default=["normal"], - choices=["normal", "lognormal", "uniform", "bucket_uniform"]) + choices=["normal", "lognormal", "uniform", "bucket_uniform", "real"], + help="Synthetic distributions. Use 'real' (or --real-histograms) to " + "sample scores from a calibrated raw_histograms.npy.") + p.add_argument("--real-histograms", type=str, default=None, + help="Path to raw_histograms.npy from calibrate_topk.py. When set, a " + "'real' distribution is appended to the sweep so every " + "(mode, hparam) combo is also timed on the calibrated score " + "distribution.") p.add_argument("--mapping-modes", type=int, nargs="+", default=[0, 3, 6, 7], help="Mapping modes to sweep (0=None, 3=Power, 6=Asinh, 7=Log1p, etc.)") @@ -503,6 +847,13 @@ def main(): for m, v in sorted(args._autotune_hparams.items()): print(f" mode {m:>2d} -> {v}") + args._real_histogram = None + if args.real_histograms: + raw = np.load(args.real_histograms) + args._real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + print(f"[real] loaded calibrated histogram from {args.real_histograms} " + f"(shape={raw.shape} → [256] aggregate)") + args._mapping_lut = None args._mapping_quantiles = None if args.lut_path: diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py index e3524c1f..4914133f 100644 --- a/benchmarks/calibrate_topk.py +++ b/benchmarks/calibrate_topk.py @@ -38,6 +38,17 @@ def main(): parser.add_argument("--topk-val", type=int, default=30) parser.add_argument("--page-size", type=int, default=16) parser.add_argument("--mem", type=float, default=0.7) + parser.add_argument( + "--max-total-tokens", + type=int, + default=1048576, + help="Hard cap on KV pool token slots (ServerArgs.max_total_tokens). " + "Block-sparse profiling uses a small bytes/token estimate, so the auto " + "budget can be huge on large GPUs; VTXGraphAttnBackend then allocates " + "dense bf16 sparse_prefill K/V buffers proportional to this cap (~4 KiB per " + "token per buffer). For offline calibration, a few hundred K1M tokens " + "is usually enough.", + ) parser.add_argument("--kv-cache-dtype", type=str, default="auto") parser.add_argument("--topk-type", type=str, default="sglang") parser.add_argument("--num-prompts", type=int, default=16, @@ -76,6 +87,7 @@ def main(): vortex_module_name=args.vortex_module_name, vortex_max_seq_lens=12288, mem_fraction_static=args.mem, + max_total_tokens=args.max_total_tokens, kv_cache_dtype=args.kv_cache_dtype, vortex_topk_type=args.topk_type, vortex_topk_mapping_mode=0, # Use mode 0 during calibration diff --git a/csrc/register.cc b/csrc/register.cc index cc201c98..8aa5aea1 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -14,6 +14,12 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("eff_batch_size"), py::arg("topk_val"), py::arg("reserved_bos"), py::arg("reserved_eos"), py::arg("max_num_pages")); + m.def("topk_output_sglang_ori", &topk_output_sglang_ori, + py::arg("x"), py::arg("dense_kv_indptr"), + py::arg("indices_out"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); m.def("topk_output_sglang_fused", &topk_output_sglang_fused, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), diff --git a/csrc/register.h b/csrc/register.h index e86a9638..afdb97f0 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -98,6 +98,17 @@ const int64_t reserved_eos, const int64_t max_seq_lengths ); +void topk_output_sglang_ori( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +at::Tensor& indices_out, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + void topk_output_sglang_fused( const at::Tensor& x, const at::Tensor& dense_kv_indptr, diff --git a/csrc/topk.cu b/csrc/topk.cu index 70d2000a..081bddf4 100644 --- a/csrc/topk.cu +++ b/csrc/topk.cu @@ -196,8 +196,20 @@ const int64_t max_num_pages reserved_bos, reserved_eos ); + } else if (max_num_pages <= 8192){ + TopKOutput_BF16_Kernel<512, 16><<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + sparse_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + sparse_kv_indices.data_ptr(), + topk_val, + reserved_bos, + reserved_eos + ); } else { - TORCH_CHECK(false); + TORCH_CHECK(false, "topk_output: max_num_pages=", max_num_pages, + " exceeds the supported template ladder (8192)."); } } \ No newline at end of file diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index 447e5397..c645acbf 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -33,7 +33,24 @@ enum TopKMappingMode { MAPPING_ERF = 9, // erf(alpha * x) MAPPING_TANH = 10, // tanh(alpha * x) MAPPING_SUBTRACT = 11, // x - pivot, with pivot = power_exp (free hyperparameter) - MAPPING_EXP_STRETCH = 13, // exp(alpha * x) + MAPPING_EXP_STRETCH = 13, // exp(alpha * x) + // Top-spreading transforms (see CLAUDE.md / remap bench plan): + // amplify differences in the high-score region so the top-K values + // occupy multiple Stage-1 bins instead of collapsing into one. + MAPPING_SHIFT_POW2 = 15, // sign(x - p) * (x - p)^2 [p = power_exp] + MAPPING_SHIFT_POW3 = 16, // (x - p)^3 [p = power_exp] + MAPPING_LINEAR_STEEP = 17, // x + k * max(x, 0) [k = power_exp] + // One-sided spread: collapse below-pivot values into a single bin so + // every above-pivot page gets its own slice of the 256-bin histogram. + MAPPING_HALF_SQUARE = 18, // max(x - p, 0)^2 [p = power_exp] + MAPPING_HALF_CUBE = 19, // max(x - p, 0)^3 [p = power_exp] + // Bit-level remap: identity value transform, but the Stage-1 bucket + // function in fast_topk_clean_fused switches to a mantissa-heavy bit + // slice (bits [23:16] of convert_to_uint32) that gives 128 sub-bins + // per exponent slot instead of 4. Zero per-element compute overhead; + // the "remap" is the bucket change. Monotonic within 2 adjacent + // fp32 exponent slots. + MAPPING_DENSE_MANT = 20, // identity; bucketing handled in fused kernel }; struct TopKMappingParams { @@ -75,24 +92,99 @@ __device__ __forceinline__ float transform_exp_stretch(float x, float alpha) { return expf(z); } +// Signed squared distance from a pivot. ~3 ops (1 sub, 1 mul, 1 copysign). +// Quadratically amplifies differences between values far from pivot so the +// top-K region gets spread across multiple Stage-1 bins. +__device__ __forceinline__ float transform_shift_pow2(float x, float pivot) { + const float d = x - pivot; + return copysignf(d * d, d); +} + +// Signed cubic of distance from pivot. ~3 ops (1 sub, 2 mul; odd function so +// no copysign). Steeper growth than pow2 for even tighter top-K clusters. +__device__ __forceinline__ float transform_shift_pow3(float x, float pivot) { + const float d = x - pivot; + return d * d * d; +} + +// Half-range linear stretch: positive values get multiplied by (1 + k), +// negative values pass through untouched. ~2 ops (fmax + fma). For softmax- +// style attention scores (which are non-negative after softmax), k = 8..16 +// shifts the positive fp16 exponent up by 3..4 slots and empties out the +// collision at the top of the distribution. +__device__ __forceinline__ float transform_linear_steep(float x, float k) { + return fmaf(k, fmaxf(x, 0.0f), x); +} + +// One-sided shifted square: values below pivot collapse to 0 (they all end +// up in the same low Stage-1 bin), above-pivot values are squared so their +// differences amplify quadratically. ~2 ops (fmax + mul). The whole 256-bin +// histogram becomes dedicated to the top slice of the distribution. +__device__ __forceinline__ float transform_half_square(float x, float pivot) { + const float d = fmaxf(x - pivot, 0.0f); + return d * d; +} + +// One-sided shifted cube: like half_square but cubic. ~3 ops. Best when the +// top-K region is even more tightly clustered and needs steeper amplification. +__device__ __forceinline__ float transform_half_cube(float x, float pivot) { + const float d = fmaxf(x - pivot, 0.0f); + return d * d * d; +} + +// Compile-time templated dispatcher. When the caller knows the mapping mode +// at template-instantiation time, this lets the compiler fully inline the +// transform into the Stage-1 inner loop and eliminate the runtime switch +// that `apply_transform` would otherwise perform per element. Used by the +// per-mode specializations of `fast_topk_clean_fused` in topk_sglang.cu. +template +__device__ __forceinline__ float apply_transform_tmpl(float x, float p) { + if constexpr (MODE == MAPPING_POWER) return transform_power(x, p); + else if constexpr (MODE == MAPPING_LOG) return transform_log(x); + else if constexpr (MODE == MAPPING_ASINH) return transform_asinh(x, p); + else if constexpr (MODE == MAPPING_LOG1P) return transform_log1p(x, p); + else if constexpr (MODE == MAPPING_ERF) return transform_erf(x, p); + else if constexpr (MODE == MAPPING_TANH) return transform_tanh(x, p); + else if constexpr (MODE == MAPPING_SUBTRACT) return x - p; + else if constexpr (MODE == MAPPING_EXP_STRETCH) return transform_exp_stretch(x, p); + else if constexpr (MODE == MAPPING_SHIFT_POW2) return transform_shift_pow2(x, p); + else if constexpr (MODE == MAPPING_SHIFT_POW3) return transform_shift_pow3(x, p); + else if constexpr (MODE == MAPPING_LINEAR_STEEP) return transform_linear_steep(x, p); + else if constexpr (MODE == MAPPING_HALF_SQUARE) return transform_half_square(x, p); + else if constexpr (MODE == MAPPING_HALF_CUBE) return transform_half_cube(x, p); + else if constexpr (MODE == MAPPING_DENSE_MANT) return fmaxf(x, p); + else return x; // NONE / TRUNC8 +} + // Pure element-wise dispatcher. Returns the *float value* after the transform. // For bin-selection modes (LUT_CDF / QUANTILE) this is identity: the mapping // happens in compute_stage1_bin() below instead of via a float transform, so // Stage-2 tie-breaking uses the raw score bits for those modes. __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { switch (params.mode) { - case MAPPING_POWER: return transform_power(x, params.power_exp); - case MAPPING_LOG: return transform_log(x); - case MAPPING_ASINH: return transform_asinh(x, params.power_exp); - case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); - case MAPPING_ERF: return transform_erf(x, params.power_exp); - case MAPPING_TANH: return transform_tanh(x, params.power_exp); - case MAPPING_SUBTRACT: return x - params.power_exp; - case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + case MAPPING_POWER: return transform_power(x, params.power_exp); + case MAPPING_LOG: return transform_log(x); + case MAPPING_ASINH: return transform_asinh(x, params.power_exp); + case MAPPING_LOG1P: return transform_log1p(x, params.power_exp); + case MAPPING_ERF: return transform_erf(x, params.power_exp); + case MAPPING_TANH: return transform_tanh(x, params.power_exp); + case MAPPING_SUBTRACT: return x - params.power_exp; + case MAPPING_EXP_STRETCH: return transform_exp_stretch(x, params.power_exp); + case MAPPING_SHIFT_POW2: return transform_shift_pow2(x, params.power_exp); + case MAPPING_SHIFT_POW3: return transform_shift_pow3(x, params.power_exp); + case MAPPING_LINEAR_STEEP: return transform_linear_steep(x, params.power_exp); + case MAPPING_HALF_SQUARE: return transform_half_square(x, params.power_exp); + case MAPPING_HALF_CUBE: return transform_half_cube(x, params.power_exp); + // MAPPING_DENSE_MANT clamps small/negative values to `power_exp` + // (default 0.5) so the subsequent dense bit bucket in the fused + // kernel sees a narrow 1–2 exponent window of positive values. + // Values at/below the clamp all hash to the lowest bin, which + // is always below the topk threshold in practice. + case MAPPING_DENSE_MANT: return fmaxf(x, params.power_exp); case MAPPING_LUT_CDF: case MAPPING_QUANTILE: case MAPPING_TRUNC8: - default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE + default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE } } diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 2466f570..fa9c8250 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -99,6 +99,23 @@ __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); } +// Mantissa-heavy Stage-1 bucket for MAPPING_DENSE_MANT. Returns bits +// [23:16] of the sign-adjusted float32 key = 1 exp LSB + 7 top +// mantissa bits. This yields 128 mantissa sub-bins per exp slot (vs +// 4 in the current fp16 scheme — 32× more resolution) and is strictly +// monotonic across 2 adjacent fp32 exponent slots (factor-of-4 value +// range). Designed for the common case where the top-K scores cluster +// tightly: softmax-attention outputs on Qwen / Llama typically live +// in ~1 exp slot of magnitude near the top. Values with exponents +// outside the 2-slot monotonic window collide with lower bins, which +// only causes a correctness issue if top-K elements span more than +// 2 exp slots — verified empirically before shipping. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + // ---- Vortex additions ---- template @@ -632,7 +649,7 @@ __device__ void fast_topk_clean( // benchmarking kernel, the remapped Stage-2 ordering is acceptable. // No pre-pass, no LUT, no shared-memory mapping state. // ====================================================================== -template +template __device__ void fast_topk_clean_fused( const ScoreT* __restrict__ input, int* __restrict__ index, @@ -651,34 +668,48 @@ __device__ void fast_topk_clean_fused( alignas(128) __shared__ int f_threshold_bin_id; alignas(128) __shared__ int f_num_input[2]; - // Shared-memory tables for MAPPING_LUT_CDF / MAPPING_QUANTILE. Loaded - // once at kernel entry and read per element in Stage 1. Other modes - // leave them untouched. - __shared__ uint8_t s_mapping_lut[256]; - __shared__ float s_mapping_quantiles[256]; + // Per-element Stage-1 bin cache. Pass 1 of Stage 1 writes one byte per + // element; pass 2 reads it back so each element only pays a single + // apply_transform + global score read instead of two. Sized to the + // maximum `pages_per_seg` the bench drivers use (topk=2048 config has + // seq_len=32768 / page_size=8 = 4096 pages per segment; topk=30 has + // 2048). Shrinking from 8192 to 4096 freed 4 KB of static SMEM per + // block, which lifts occupancy from 5 → 6 blocks/SM on B200. + constexpr int kFusedMaxLen = 4096; + __shared__ uint8_t s_bins[kFusedMaxLen]; auto& f_histogram = f_histogram_buf[0]; extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; const int tx = threadIdx.x; - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } + // MODE is a compile-time template parameter, so every comparison below + // becomes a constant-folded `if constexpr` branch. The dense bucket + // path (MAPPING_DENSE_MANT) stays in the kernel but is completely + // elided when MODE != MAPPING_DENSE_MANT, and the value-space transform + // path stays in place for standard modes. LUT_CDF / QUANTILE are not + // supported by this templated kernel (they were dropped from the bench + // comparison earlier). + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); - // Stage 1: LUT/QUANTILE do a shared-memory lookup, everything else - // applies the element-wise transform then buckets via convert_to_uint8. + // Stage 1 pass 1: read each score from global, compute the Stage-1 + // bin via the compile-time-dispatched transform, cache it in s_bins so + // pass 2 can skip the second global read. With MODE known at compile + // time, apply_transform_tmpl inlines to just the chosen + // transform's instructions — no runtime switch overhead. for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + int bin; + if constexpr (use_dense_bucket) { + bin = static_cast(convert_to_uint8_dense(remapped)); + } else { + bin = static_cast(convert_to_uint8(remapped)); + } + s_bins[idx] = static_cast(bin); ::atomicAdd(&f_histogram[bin], 1); } __syncthreads(); @@ -712,10 +743,11 @@ __device__ void fast_topk_clean_fused( topk -= f_histogram[threshold_bin + 1]; if (topk == 0) { + // Shortcut: every page above threshold gets selected. Read the bin + // from the cache so we don't re-touch global memory or recompute + // apply_transform. for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + const int bin = static_cast(s_bins[idx]); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); index[pos] = idx; @@ -728,20 +760,33 @@ __device__ void fast_topk_clean_fused( if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); + // Stage 1 pass 2: read the cached bin from SMEM. For elements + // outside the threshold bin we skip the global-memory load AND the + // apply_transform call entirely. Only the ~thr_size threshold-bin + // candidates re-read raw and re-apply the templated transform to + // compute the sub-bin needed for Stage-2 refinement. + // + // Sub-bin shift selection (compile-time constant): + // - standard modes: Stage-1 used fp16 top-8-bit bucketing, so + // Stage-2 round 0 refines on uint32 bits [31:24] (the most + // significant bits not captured by the fp16 bucket). + // - MAPPING_DENSE_MANT: Stage-1 used bits [23:16], so the next + // useful discriminator is bits [15:8]. Skipping to offset 8 + // directly avoids two wasted Stage-2 rounds. + constexpr int sub_bin_offset_start = use_dense_bucket ? 8 : 24; for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform(raw, mapping); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + const int bin = static_cast(s_bins[idx]); if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); index[pos] = idx; } else if (bin == threshold_bin) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); const auto pos = ::atomicAdd(&f_num_input[0], 1); if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { f_input_idx[0][pos] = idx; const auto b32 = convert_to_uint32(remapped); - const auto sub_bin = (b32 >> 24) & 0xFF; + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; ::atomicAdd(&f_histogram[sub_bin], 1); } } @@ -749,9 +794,17 @@ __device__ void fast_topk_clean_fused( __syncthreads(); } - // stage 2: refine on raw bits of the remapped value + // stage 2: refine on raw bits of the remapped value. The per-round + // bit offset matches the sub_bin shift chosen above: standard modes + // start at offset 24 (bits [31:24]) and step down by 8 per round; + // MAPPING_DENSE_MANT starts at offset 8 (bits [15:8]) because Stage 1 + // already consumed bits [23:16] in the dense bucket. Both values are + // compile-time constants since MODE is a template parameter. + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; #pragma unroll 4 for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; __shared__ int f_last_remain; const auto r_idx = round % 2; @@ -772,9 +825,9 @@ __device__ void fast_topk_clean_fused( if (topk == 0) { for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = f_input_idx[r_idx][i]; - const auto offset = 24 - round * 8; + const auto offset = stage2_offset_start - round * 8; const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform(raw, mapping); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); @@ -790,14 +843,18 @@ __device__ void fast_topk_clean_fused( for (int i = tx; i < num_input; i += BLOCK_SIZE) { const auto idx = f_input_idx[r_idx][i]; const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform(raw, mapping); - const auto offset = 24 - round * 8; + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto offset = stage2_offset_start - round * 8; const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&f_counter, 1); index[pos] = idx; } else if (bin == threshold_bin) { - if (round == 3) { + // Last refinement round: we have no more discriminator bits + // below the current offset, so emit any remaining elements as + // "tie-break fallback" via f_last_remain (ensures topk is met + // even when thr_size > sel_thr at the finest granularity). + if (round == stage2_max_rounds - 1) { const auto pos = ::atomicAdd(&f_last_remain, -1); if (pos > 0) { index[target_k - pos] = idx; @@ -855,7 +912,7 @@ void TopKOutput_Clean_Kernel( } } -template +template __global__ __launch_bounds__(kThreadsPerBlock) void TopKOutput_Fused_Kernel( const ScoreT* __restrict__ score, @@ -882,7 +939,7 @@ void TopKOutput_Fused_Kernel( + page_reserved_bos; __shared__ int s_indices[VORTEX_MAX_TOPK]; - fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); + fast_topk_clean_fused(score_blk, s_indices, 0, nblk, topk_val, mapping); __syncthreads(); const int tx = threadIdx.x; @@ -891,16 +948,32 @@ void TopKOutput_Fused_Kernel( } } +// Inverse of vortex_to_float: narrow a float back to ScoreT for the +// bf16-output remap path so the subsequent topk kernel can read half +// the bytes of a fp32 remapped buffer. +template +__device__ __forceinline__ T float_to_vortex(float x); +template <> +__device__ __forceinline__ float float_to_vortex(float x) { return x; } +template <> +__device__ __forceinline__ __nv_bfloat16 float_to_vortex<__nv_bfloat16>(float x) { + return __float2bfloat16(x); +} + // Remap-only kernel: applies the element-wise transform to each score // in the [dense_kv_indptr[b] + reserved_bos, dense_kv_indptr[b+1] - reserved_eos) -// range and writes the result into a float32 output tensor. Used by -// the split-phase benchmark (remap → unmapped topk). -template +// range and writes the result into an output tensor (OutT = float or +// bf16). Used by the split-phase benchmark (remap → unmapped topk). +// Writing bf16 halves memory bandwidth on the output and on the +// subsequent topk read; precision-wise it's lossless for the Stage-1 +// 8-bit bucket because fp16/bf16 both discard more mantissa than the +// bucket uses. +template __global__ __launch_bounds__(kThreadsPerBlock) void TopKRemapOnly_Kernel( const ScoreT* __restrict__ score, const int* __restrict__ dense_kv_indptr, - float* __restrict__ remapped, + OutT* __restrict__ remapped, const int page_reserved_bos, const int page_reserved_eos, const TopKMappingParams mapping) @@ -914,10 +987,11 @@ void TopKRemapOnly_Kernel( if (nblk <= 0) return; const ScoreT* __restrict__ score_blk = score + start; - float* __restrict__ remap_blk = remapped + start; + OutT* __restrict__ remap_blk = remapped + start; for (int i = tx; i < nblk; i += kThreadsPerBlock) { - remap_blk[i] = apply_transform(vortex_to_float(score_blk[i]), mapping); + const float y = apply_transform(vortex_to_float(score_blk[i]), mapping); + remap_blk[i] = float_to_vortex(y); } } @@ -1118,6 +1192,18 @@ void topk_output_sglang_fused( "topk_output_sglang_fused: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + // Caller contract: max_num_pages must be <= 4096, the static SMEM + // `s_bins` cache size inside the templated fused kernel. The bench + // drivers stay within this bound; no runtime check is emitted in + // the hot path. + + // The `mapping_lut` / `mapping_quantiles` optional tensors are + // retained in the pybind signature for API backward compatibility + // but are ignored: the templated fused kernel drops the LUT_CDF / + // QUANTILE code paths entirely. + (void)mapping_lut; + (void)mapping_quantiles; + CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); CHECK_CUDA(sparse_kv_indptr); @@ -1125,51 +1211,66 @@ void topk_output_sglang_fused( CHECK_CUDA(sparse_kv_indices); TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); + mapping.mode = static_cast(mapping_mode); mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; + mapping.lut = nullptr; mapping.quantiles = nullptr; - if (mapping_lut.has_value()) { - const auto& lut = mapping_lut.value(); - CHECK_CUDA(lut); - TORCH_CHECK(lut.dim() == 1 && lut.size(0) == 256 && lut.scalar_type() == at::ScalarType::Byte, - "mapping_lut must be a 1D uint8 tensor of size 256"); - mapping.lut = lut.data_ptr(); - } - if (mapping_quantiles.has_value()) { - const auto& q = mapping_quantiles.value(); - CHECK_CUDA(q); - TORCH_CHECK(q.dim() == 1 && q.size(0) == 256 && q.scalar_type() == at::ScalarType::Float, - "mapping_quantiles must be a 1D float32 tensor of size 256"); - mapping.quantiles = q.data_ptr(); - } dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + // Each mapping mode compiles to its own kernel specialization so + // apply_transform_tmpl is fully inlined (no runtime switch on + // mode in the inner loop). The wrapper's outer dispatch is a one- + // time per-call cost, negligible relative to the kernel runtime. + #define VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, kSmem>(); \ + TopKOutput_Fused_Kernel<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + topk_val, reserved_bos, reserved_eos, mapping); \ + } while (0) + + #define VORTEX_DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping.mode) { \ + case MAPPING_NONE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_TRUNC8: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break; \ + case MAPPING_ERF: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP:VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + case MAPPING_DENSE_MANT: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break; \ + default: \ + TORCH_CHECK(false, "topk_output_sglang_fused: unsupported mapping_mode ", mapping.mode); \ + } \ + } while (0) + if (x.scalar_type() == at::ScalarType::BFloat16) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Fused_Kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, mapping); + VORTEX_DISPATCH_MODE(__nv_bfloat16, reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); } else if (x.scalar_type() == at::ScalarType::Float) { - setup_kernel_smem_once, kSmem>(); - TopKOutput_Fused_Kernel<<>>( - x.data_ptr(), - dense_kv_indptr.data_ptr(), - sparse_kv_indptr.data_ptr(), - dense_kv_indices.data_ptr(), - sparse_kv_indices.data_ptr(), - topk_val, reserved_bos, reserved_eos, mapping); + VORTEX_DISPATCH_MODE(float, x.data_ptr()); } else { TORCH_CHECK(false, "topk_output_sglang_fused: unsupported dtype ", x.scalar_type()); } + #undef VORTEX_DISPATCH_MODE + #undef VORTEX_DISPATCH_FUSED + const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, "topk_output_sglang_fused kernel failed: ", ::cudaGetErrorString(result)); @@ -1183,7 +1284,7 @@ void topk_output_sglang_fused( void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, - at::Tensor& remapped, // float32, same numel as x + at::Tensor& remapped, // float32 or bfloat16, same numel as x const int64_t eff_batch_size, const int64_t reserved_bos, const int64_t reserved_eos, @@ -1193,35 +1294,57 @@ void topk_remap_only( CHECK_CUDA(x); CHECK_CUDA(dense_kv_indptr); CHECK_CUDA(remapped); - TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float, - "remapped output must be float32"); + TORCH_CHECK(remapped.scalar_type() == at::ScalarType::Float + || remapped.scalar_type() == at::ScalarType::BFloat16, + "remapped output must be float32 or bfloat16"); TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); + mapping.mode = static_cast(mapping_mode); mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; + mapping.lut = nullptr; mapping.quantiles = nullptr; dim3 nblks(eff_batch_size); dim3 nthreads(kThreadsPerBlock); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - if (x.scalar_type() == at::ScalarType::BFloat16) { - TopKRemapOnly_Kernel<__nv_bfloat16><<>>( + // Four-way dispatch on (input dtype, output dtype). bf16→bf16 is the + // new "batch pre-transform" path that halves memory bandwidth vs the + // fp32 output: the remap writes half the bytes and the subsequent + // topk_output_sglang reads half the bytes. Precision is preserved + // because Stage-1 bucketing only uses the top 8 bits of an fp16 key + // which both fp32 and bf16 capture. + #define VORTEX_DISPATCH_REMAP(IN_CPP, OUT_CPP, IN_PTR_EXPR, OUT_PTR_EXPR) \ + TopKRemapOnly_Kernel<<>>( \ + IN_PTR_EXPR, dense_kv_indptr.data_ptr(), OUT_PTR_EXPR, \ + reserved_bos, reserved_eos, mapping) + + const bool in_bf16 = (x.scalar_type() == at::ScalarType::BFloat16); + const bool in_fp32 = (x.scalar_type() == at::ScalarType::Float); + const bool out_bf16 = (remapped.scalar_type() == at::ScalarType::BFloat16); + + if (in_bf16 && out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, __nv_bfloat16, reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), - dense_kv_indptr.data_ptr(), - remapped.data_ptr(), - reserved_bos, reserved_eos, mapping); - } else if (x.scalar_type() == at::ScalarType::Float) { - TopKRemapOnly_Kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_bf16 && !out_bf16) { + VORTEX_DISPATCH_REMAP(__nv_bfloat16, float, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + remapped.data_ptr()); + } else if (in_fp32 && out_bf16) { + VORTEX_DISPATCH_REMAP(float, __nv_bfloat16, x.data_ptr(), - dense_kv_indptr.data_ptr(), - remapped.data_ptr(), - reserved_bos, reserved_eos, mapping); + reinterpret_cast<__nv_bfloat16*>(remapped.data_ptr())); + } else if (in_fp32 && !out_bf16) { + VORTEX_DISPATCH_REMAP(float, float, + x.data_ptr(), + remapped.data_ptr()); } else { TORCH_CHECK(false, "topk_remap_only: unsupported dtype ", x.scalar_type()); } + #undef VORTEX_DISPATCH_REMAP + const auto result = cudaGetLastError(); TORCH_CHECK(result == cudaSuccess, "topk_remap_only kernel failed: ", ::cudaGetErrorString(result)); diff --git a/csrc/topk_sglang_ori.cu b/csrc/topk_sglang_ori.cu new file mode 100644 index 00000000..55a99b21 --- /dev/null +++ b/csrc/topk_sglang_ori.cu @@ -0,0 +1,619 @@ +/** + * @NOTE: This file is adapted from + * https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_v32/topk_selector.py + * We: + * 1. adapt from tilelang to pure cuda + * 2. optimize the performance a little + * 3. fix the potential illegal memory access + */ + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + namespace { + + // NOTE: TopK is a compile-time constant here because shared-memory + // allocations inside the transform kernels depend on it. We drop it to + // 30 to match the vortex benchmark's --topk-val 30 configuration. The + // transform kernels (decode/prefill/prefill_ragged) still carry a manual + // unroll that assumes TopK==2048; that code path is unreachable from the + // bench (we only invoke fast_topk_interface), so the corresponding + // static_asserts have been removed below. + constexpr int TopK = 30; + constexpr int kThreadsPerBlock = 1024; + + #ifdef USE_ROCM + // On ROCm, the per-workgroup LDS budget depends on the target arch, so we inject a + // per-arch value from `setup_rocm.py` via `-DSGL_TOPK_DYNAMIC_SMEM_BYTES=...`. + #ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES + constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); + #else + constexpr size_t kSmem = 48 * 1024; // bytes + #endif + #else + // Reduced from 128KB to 32KB to improve occupancy. + // Each radix pass needs at most ~TopK candidates in the threshold bin, + // so 4K entries per round (2 rounds = 8K entries = 32KB) is sufficient. + constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) + #endif + + struct FastTopKParams { + const float* __restrict__ input; // [B, input_stride] + const int32_t* __restrict__ row_starts; // [B] + int32_t* __restrict__ indices; // [B, TopK] + int32_t* __restrict__ lengths; // [B] + int64_t input_stride; + }; + + // when length <= TopK, we can directly write the indices + __device__ void naive_topk_cuda(const float* __restrict__ score, int32_t* __restrict__ indice, int32_t length) { + const auto tid = threadIdx.x; + for (int i = tid; i < TopK; i += kThreadsPerBlock) { + indice[i] = (i < length) ? i : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform( + const float* __restrict__ score, + int32_t length, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + dst_page_table[i] = (i < length) ? src_page_table[i] : -1; + } + } + + // keep the first `length` entries, set others to -1 + __device__ void naive_topk_transform_ragged( + const float* __restrict__ score, int32_t length, int32_t* __restrict__ topk_indices_ragged, int32_t offset) { + const auto tid = threadIdx.x; + for (auto i = tid; i < TopK; i += kThreadsPerBlock) { + topk_indices_ragged[i] = (i < length) ? static_cast(i) + offset : -1; + } + } + + __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + __device__ void fast_topk_cuda_tl(const float* __restrict__ input, int* __restrict__ index, int row_start, int length) { + // An optimized topk kernel copied from tilelang kernel + // We assume length > TopK here, or it will crash + int topk = TopK; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + // allocate for two rounds + extern __shared__ int s_input_idx[][SMEM_INPUT_SIZE]; + + const int tx = threadIdx.x; + + // stage 1: 8bit coarse histogram + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = convert_to_uint8(input[idx + row_start]); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) { + value += s_histogram_buf[k][tx + j]; + } + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto bin = static_cast(convert_to_uint8(input[idx + row_start])); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const auto raw_input = input[idx + row_start]; + const auto bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + /// NOTE: (dark) fuse the histogram computation here + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx[0][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> 24) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + // stage 2: refine with 8bit radix passes + #pragma unroll 4 + for (int round = 0; round < 4; ++round) { + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + // clip here to prevent overflow + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input : int(SMEM_INPUT_SIZE); + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) { + s_histogram[tx] = 0; + } + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx[r_idx][i]; + const auto raw_input = input[idx + row_start]; + const auto offset = 24 - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == 3) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + index[TopK - pos] = idx; + } + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + /// NOTE: (dark) fuse the histogram computation here + s_input_idx[r_idx ^ 1][pos] = idx; + const auto bin = convert_to_uint32(raw_input); + const auto sub_bin = (bin >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // topk + void topk_kernel(const FastTopKParams params) { + const auto& [input, row_starts, indices, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto indice = indices + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_cuda(score, indice, length); + } else { + return fast_topk_cuda_tl(score, indice, row_start, length); + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // decode + void topk_transform_decode_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride) { + const auto& [input, _1, _2, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = 0; + const auto length = lengths[bid]; + const auto src_page_entry = src_page_table + bid * src_stride; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill + void topk_transform_prefill_kernel( + const FastTopKParams params, + int32_t* __restrict__ dst_page_table, + const int32_t* __restrict__ src_page_table, + const int64_t src_stride, + const int32_t* __restrict__ cu_seqlens_q, + const int64_t prefill_bs) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto length = lengths[bid]; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto dst_page_entry = dst_page_table + bid * TopK; + const auto score = input + bid * input_stride; + + /// NOTE: prefill bs is usually small, we can just use a simple loop here + /// We ensure that last cu_seqlens is equal to number of blocks launched + __shared__ const int32_t* s_src_page_entry; + if (C10_LIKELY(prefill_bs <= kThreadsPerBlock)) { + if (tid < prefill_bs) { + if (bid >= cu_seqlens_q[tid] && bid < cu_seqlens_q[tid + 1]) { + s_src_page_entry = src_page_table + tid * src_stride; + } + } + } else { + for (int64_t i = tid; i < prefill_bs; i += kThreadsPerBlock) { + if (bid >= cu_seqlens_q[i] && bid < cu_seqlens_q[i + 1]) { + s_src_page_entry = src_page_table + i * src_stride; + } + } + } + __syncthreads(); + const auto src_page_entry = s_src_page_entry; + + if (length <= TopK) { + return naive_topk_transform(score, length, dst_page_entry, src_page_entry); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_page_entry[idx_0] = src_page_entry[pos_0]; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_page_entry[idx_1] = src_page_entry[pos_1]; + } + } + + __global__ __launch_bounds__(kThreadsPerBlock) // prefill, ragged kv + void topk_transform_prefill_ragged_kernel( + const FastTopKParams params, + int32_t* __restrict__ topk_indices_ragged, + const int32_t* __restrict__ topk_indices_offset) { + const auto& [input, row_starts, _, lengths, input_stride] = params; + const auto bid = static_cast(blockIdx.x); + const auto tid = threadIdx.x; + const auto row_start = row_starts == nullptr ? 0 : row_starts[bid]; + const auto length = lengths[bid]; + const auto dst_indices_entry = topk_indices_ragged + bid * TopK; + const auto score = input + bid * input_stride; + const auto offset = topk_indices_offset[bid]; + + if (length <= TopK) { + return naive_topk_transform_ragged(score, length, dst_indices_entry, offset); + } else { + __shared__ int s_indices[TopK]; + fast_topk_cuda_tl(score, s_indices, row_start, length); + // copy src[s_indices] to dst, we manually unroll here + // (static_asserts removed because TopK != 2048 in this build; the + // manual unroll below is unreachable from bench_topk.py which only + // calls fast_topk_interface, not this transform variant.) + const auto idx_0 = tid; + const auto pos_0 = s_indices[idx_0]; + dst_indices_entry[idx_0] = pos_0 + offset; + const auto idx_1 = tid + kThreadsPerBlock; + const auto pos_1 = s_indices[idx_1]; + dst_indices_entry[idx_1] = pos_1 + offset; + } + } + + auto get_params( + const at::Tensor& score, + const at::Tensor& lengths, + std::optional row_starts_opt = std::nullopt, + std::optional indices_opt = std::nullopt) -> FastTopKParams { + const auto B = score.size(0); + TORCH_CHECK(score.dim() == 2 && score.stride(1) == 1); + if (row_starts_opt.has_value()) { + const auto& row_starts = row_starts_opt.value(); + TORCH_CHECK(row_starts.dim() == 1); + TORCH_CHECK(row_starts.size(0) == B); + } + TORCH_CHECK(lengths.dim() == 1 && lengths.is_contiguous()); + TORCH_CHECK(lengths.size(0) == B); + int32_t* indices_data_ptr = nullptr; + if (indices_opt.has_value()) { + const auto& indices = indices_opt.value(); + TORCH_CHECK(indices.dim() == 2 && indices.is_contiguous()); + TORCH_CHECK(indices.size(0) == B); + TORCH_CHECK(indices.size(1) == TopK); + indices_data_ptr = indices.data_ptr(); + } + + return FastTopKParams{ + .input = score.data_ptr(), + .row_starts = row_starts_opt.has_value() ? row_starts_opt->data_ptr() : nullptr, + .indices = indices_data_ptr, + .lengths = lengths.data_ptr(), + .input_stride = score.stride(0), + }; + } + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + #ifdef USE_ROCM + // hipify will turn cudaFuncSetAttribute -> hipFuncSetAttribute. On ROCm, + // hipFuncSetAttribute expects `const void*` and hipcc does not accept passing + // a function pointer directly, so cast explicitly. + return ::cudaFuncSetAttribute( + reinterpret_cast(f), ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #else + // CUDA: keep original behavior (no cast needed). + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + #endif + }(); + TORCH_CHECK(result == cudaSuccess, "set_up_kernel_once failed:", ::cudaGetErrorString(result)); + } + + } // namespace + + // The public interface functions below collide by name with identically + // named symbols in topk_sglang.cu. Wrap them in `sglang_ori` so both + // translation units can be linked into the same vortex_torch_C extension. + namespace sglang_ori { + + #ifndef CHECK_CUDA + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + #endif + + void fast_topk_interface( + const at::Tensor& score, at::Tensor& indices, const at::Tensor& lengths, std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(indices); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + CHECK_CUDA(lengths); + const auto params = get_params(score, lengths, row_starts_opt, indices); + const auto B = score.size(0); + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + setup_kernel_smem_once(); + topk_kernel<<>>(params); + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& dst_page_table, + const at::Tensor& src_page_table, + const at::Tensor& cu_seqlens_q, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(dst_page_table); + CHECK_CUDA(src_page_table); + CHECK_CUDA(cu_seqlens_q); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(dst_page_table.dim() == 2 && dst_page_table.is_contiguous()); + TORCH_CHECK(src_page_table.dim() == 2 && src_page_table.stride(1) == 1); + TORCH_CHECK(cu_seqlens_q.dim() == 1 && cu_seqlens_q.is_contiguous()); + const auto prefill_bs = cu_seqlens_q.size(0) - 1; + TORCH_CHECK(dst_page_table.size(0) == B); + TORCH_CHECK(dst_page_table.size(1) == TopK); + TORCH_CHECK(src_page_table.size(0) == prefill_bs); + TORCH_CHECK(prefill_bs <= B); // prefill_bs should be smaller than expanded bs + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + const auto src_stride = src_page_table.stride(0); + + // dispatch to decode or prefill + // extend and draft extend: row_starts_opt is not null, invokes the prefill kernel + // decode: row_starts_opt is null, invokes the decode kernel + // target verify: row_starts_opt is null, invokes the prefill kernel + const auto is_decode = !row_starts_opt.has_value() && prefill_bs == B; + if (is_decode) { + setup_kernel_smem_once(); + topk_transform_decode_kernel<<>>( + params, dst_page_table.data_ptr(), src_page_table.data_ptr(), src_stride); + } else { + setup_kernel_smem_once(); + topk_transform_prefill_kernel<<>>( + params, + dst_page_table.data_ptr(), + src_page_table.data_ptr(), + src_stride, + cu_seqlens_q.data_ptr(), + prefill_bs); + } + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + void fast_topk_transform_ragged_interface( + const at::Tensor& score, + const at::Tensor& lengths, + at::Tensor& topk_indices_ragged, + const at::Tensor& topk_indices_offset, + std::optional row_starts_opt) { + CHECK_CUDA(score); + CHECK_CUDA(lengths); + CHECK_CUDA(topk_indices_ragged); + CHECK_CUDA(topk_indices_offset); + if (row_starts_opt.has_value()) { + CHECK_CUDA(row_starts_opt.value()); + } + + const auto params = get_params(score, lengths, row_starts_opt); + const auto B = score.size(0); + TORCH_CHECK(topk_indices_ragged.dim() == 2 && topk_indices_ragged.is_contiguous()); + TORCH_CHECK(topk_indices_offset.dim() == 1); + + TORCH_CHECK(topk_indices_ragged.size(0) == B); + TORCH_CHECK(topk_indices_ragged.size(1) == TopK); + TORCH_CHECK(topk_indices_offset.size(0) == B); + + // launch kernel + const auto stream = at::cuda::getCurrentCUDAStream().stream(); + const auto grid = dim3{static_cast(B)}; + const auto block = dim3{kThreadsPerBlock}; + + setup_kernel_smem_once(); + topk_transform_prefill_ragged_kernel<<>>( + params, topk_indices_ragged.data_ptr(), topk_indices_offset.data_ptr()); + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, "topk kernel failed:", ::cudaGetErrorString(result)); + } + + } // namespace sglang_ori + +// ====================================================================== +// Thin vortex_torch_C adapter: accepts the same CSR-ish inputs as +// topk_output_sglang so bench_topk.py can treat the original SGLang kernel +// as an alternate baseline. The ori kernel has TopK baked in as a compile- +// time constant; this build sets it to 30 to match --topk-val 30. +// ====================================================================== +void topk_output_sglang_ori( + const at::Tensor& x, // [total_dense, 1, 1] or [total_dense], bf16/fp32 + const at::Tensor& dense_kv_indptr, // int32 [eff_bs + 1] (unused — synthetic bench rows are uniform) + at::Tensor& indices_out, // int32 [eff_bs, TopK] + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) +{ + TORCH_CHECK(x.is_cuda(), "x must be a CUDA tensor"); + TORCH_CHECK(dense_kv_indptr.is_cuda(), "dense_kv_indptr must be a CUDA tensor"); + TORCH_CHECK(indices_out.is_cuda(), "indices_out must be a CUDA tensor"); + TORCH_CHECK(indices_out.scalar_type() == at::ScalarType::Int, + "indices_out must be int32"); + TORCH_CHECK(topk_val == static_cast(30), + "topk_output_sglang_ori: this build of the ori kernel hard-codes TopK=30; " + "rebuild topk_sglang_ori.cu with a different TopK if you need another value. " + "Got topk_val=", topk_val); + TORCH_CHECK(indices_out.dim() == 2 + && indices_out.size(0) == eff_batch_size + && indices_out.size(1) == 30, + "indices_out must be [eff_batch_size, 30]"); + + // ori kernel requires fp32 [B, stride] scores. Caller typically passes + // the bf16 score tensor; we materialize an fp32 view once per call. + at::Tensor score_f32; + if (x.scalar_type() == at::ScalarType::Float) { + score_f32 = x.contiguous().view({eff_batch_size, max_num_pages}); + } else if (x.scalar_type() == at::ScalarType::BFloat16) { + score_f32 = x.to(at::kFloat).contiguous().view({eff_batch_size, max_num_pages}); + } else { + TORCH_CHECK(false, "topk_output_sglang_ori: unsupported dtype ", x.scalar_type()); + } + + auto opts_i32 = at::TensorOptions().dtype(at::kInt).device(x.device()); + const int32_t usable_len = + static_cast(max_num_pages - reserved_bos - reserved_eos); + at::Tensor lengths = at::full({eff_batch_size}, usable_len, opts_i32); + at::Tensor row_starts = at::full({eff_batch_size}, + static_cast(reserved_bos), opts_i32); + + sglang_ori::fast_topk_interface( + score_f32, indices_out, lengths, + std::optional(row_starts)); +} \ No newline at end of file diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu index adba2d03..7fe99814 100644 --- a/csrc/topk_sglang_profile.cu +++ b/csrc/topk_sglang_profile.cu @@ -89,6 +89,16 @@ __device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); } +// Mirror of convert_to_uint8_dense in topk_sglang.cu so that the +// profile kernel (topk_profile_histogram / topk_profile_counters) +// reports accurate thr_bin / thr_size / abv_bins / pg/bin for +// MAPPING_DENSE_MANT. Keep in sync with the production kernel. +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + template __device__ __forceinline__ float vortex_to_float(T x); template <> @@ -164,6 +174,11 @@ __device__ void fast_topk_profile( const int tx = threadIdx.x; + // Mirror of the production kernel: MAPPING_DENSE_MANT bypasses + // apply_transform and uses a mantissa-heavy fp32 bit slice for the + // Stage-1 bucket. + const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); + if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; __syncthreads(); @@ -178,7 +193,13 @@ __device__ void fast_topk_profile( for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); // fmaxf(x, pivot) + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } ::atomicAdd(&p_histogram[bin], 1); } __syncthreads(); @@ -221,8 +242,13 @@ __device__ void fast_topk_profile( if (topk == 0) { for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } if (bin > threshold_bin_0) { const auto pos = ::atomicAdd(&p_counter, 1); index[pos] = idx; @@ -240,11 +266,13 @@ __device__ void fast_topk_profile( if (tx < RADIX + 1) p_histogram[tx] = 0; __syncthreads(); + const int sub_bin_offset_start = use_dense_bucket ? 8 : 24; for (int idx = tx; idx < length; idx += BLOCK_SIZE) { const float raw = vortex_to_float(input[idx + row_start]); const float remapped = apply_transform(raw, mapping); - const auto bin = static_cast( - compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + const auto bin = use_dense_bucket + ? static_cast(convert_to_uint8_dense(remapped)) + : static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); if (bin > threshold_bin_0) { const auto pos = ::atomicAdd(&p_counter, 1); index[pos] = idx; @@ -253,7 +281,7 @@ __device__ void fast_topk_profile( if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { p_input_idx[0][pos] = idx; const auto b32 = convert_to_uint32(remapped); - const auto sub_bin = (b32 >> 24) & 0xFF; + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; ::atomicAdd(&p_histogram[sub_bin], 1); } } @@ -265,10 +293,15 @@ __device__ void fast_topk_profile( } } - // Stage 2 refinement (4 rounds max). Default rounds=4, overwritten on exit. - if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = 4; + // Stage 2 refinement. Standard modes run up to 4 rounds (offsets + // 24/16/8/0); MAPPING_DENSE_MANT runs up to 2 rounds (offsets 8/0) + // because Stage 1 already consumed bits [23:16] of the fp32 key. + const int stage2_offset_start = use_dense_bucket ? 8 : 24; + const int stage2_max_rounds = use_dense_bucket ? 2 : 4; + if (tx == 0 && counters) counters[COUNTER_REFINE_ROUNDS] = stage2_max_rounds; #pragma unroll 4 for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; __shared__ int p_last_remain; const auto r_idx = round % 2; const auto _raw_num_input = p_num_input[r_idx]; @@ -290,7 +323,7 @@ __device__ void fast_topk_profile( const auto idx = p_input_idx[r_idx][i]; const float raw = vortex_to_float(input[idx + row_start]); const float remapped = apply_transform(raw, mapping); - const auto offset = 24 - round * 8; + const auto offset = stage2_offset_start - round * 8; const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&p_counter, 1); @@ -308,13 +341,13 @@ __device__ void fast_topk_profile( const auto idx = p_input_idx[r_idx][i]; const float raw = vortex_to_float(input[idx + row_start]); const float remapped = apply_transform(raw, mapping); - const auto offset = 24 - round * 8; + const auto offset = stage2_offset_start - round * 8; const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; if (bin > threshold_bin) { const auto pos = ::atomicAdd(&p_counter, 1); index[pos] = idx; } else if (bin == threshold_bin) { - if (round == 3) { + if (round == stage2_max_rounds - 1) { const auto pos = ::atomicAdd(&p_last_remain, -1); if (pos > 0) { index[target_k - pos] = idx; @@ -413,11 +446,18 @@ void TopKProfileHistogram_Kernel( if (tx < RADIX) s_histogram[tx] = 0; __syncthreads(); + const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); if (nblk > 0) { const ScoreT* __restrict__ score_blk = score + start; for (int i = tx; i < nblk; i += BLOCK_SIZE) { const float raw = vortex_to_float(score_blk[i]); - const auto bin = compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles); + int bin; + if (use_dense_bucket) { + const float clamped = apply_transform(raw, mapping); + bin = static_cast(convert_to_uint8_dense(clamped)); + } else { + bin = static_cast(compute_stage1_bin(raw, mapping, s_mapping_lut, s_mapping_quantiles)); + } ::atomicAdd(&s_histogram[bin], 1); } } diff --git a/examples/remap_function_bench_topk2028.sh b/examples/remap_function_bench_topk2028.sh new file mode 100755 index 00000000..6d95b59d --- /dev/null +++ b/examples/remap_function_bench_topk2028.sh @@ -0,0 +1,252 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark +# +# Compares four kernel configurations for TopK page selection: +# 1. baseline — unmapped topk (topk_output_sglang) +# 2. fused remap + topk — topk_output_sglang_fused +# 3. remap only — topk_remap_only (standalone kernel) +# 4. unmapped topk on remapped — topk_output_sglang on the output +# buffer of step 3 +# +# Per configuration the script also reports the threshold-bin +# position, the threshold-bin size, and how many values are +# selected from the threshold bin (derived from +# topk_profile_counters — collected after all timing measurements, +# never interleaved with latency measurements). +# +# Pipeline: +# 1. Calibrate — run `calibrate_topk.py` on the chosen model to +# collect the REAL per-segment topk distribution +# (raw_histograms.npy). Skippable via +# --real-histograms /path/to/raw_histograms.npy. +# 2. Autotune — run `autotune_topk_mapping.py` on those real +# histograms and pick the per-mode hyperparameter +# with the LOWEST measured topk kernel latency. +# 3. Remap bench— run `bench_topk.py --remap-bench` with the +# autotune-selected per-mode hyperparameters. +# +# Argument layout mirrors run_distribution_analysis_new.sh. +# +# Usage: +# # Default (Qwen/Qwen3-1.7B, block_size=16): +# bash remap_function_bench.sh --gpu 5 +# +# # Larger model + larger page/block size: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --block-size 32 \ +# --seq-len 16384 --topk-val 512 \ +# --modes "0 3 6 7" +# +# # Reuse an existing calibration: +# bash remap_function_bench.sh --gpu 0 \ +# --model-name Qwen/Qwen3-8B \ +# --real-histograms /path/to/calibration/raw_histograms.npy +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=5 +MODEL_NAME="Qwen/Qwen3-8B" +TOPK_VAL=2048 +MEM=0.7 +# Cap KV / VTX sparse prefill buffer sizing during Step 1 (see calibrate_topk.py --help). +MAX_TOTAL_TOKENS=64768 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +BLOCK_SIZE=8 +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their +# mapping happens inside compute_stage1_bin, not apply_transform, so +# split-phase timing isn't meaningful for them. +MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" +# Fallback hparam used only if autotune is explicitly skipped. +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Empty by default — Step 1 will calibrate on the selected model. +# Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. +# REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" +#REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms_qwen3-4B.npy" +REAL_HISTOGRAMS="" +SKIP_AUTOTUNE=0 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Validate seq_len: need pages/seg > topk_val (3 reserved pages) +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +echo "============================================================" +echo "Remap Function Benchmark" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " GPU: ${GPU_ID}" +echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate — collect real-distribution topk histograms ── +# calibrate_topk.py runs the model end-to-end with histogram profiling +# enabled and writes per-segment raw_histograms.npy. The histograms are +# aggregated over every layer and every decode/prefill step so the +# autotune in Step 2 sees the true attention-score distribution. +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${RUN_DIR}/calibration" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" + echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" +fi + +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. + +# ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── +# For every (mode, hparam) combo in the sweep grid, the autotune runs the +# fused remap+topk kernel on the real histogram and measures end-to-end +# kernel latency with CUDA events. The per-mode hparam with the lowest +# measured topk kernel latency wins. +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap benchmark (baseline / fused / remap / split) ── +echo "" +echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" +REMAP_JSON="${RUN_DIR}/remap_bench.json" +BENCH_EXTRA=() +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step3_remap_bench.log" +echo ">>> Step 3: Done. Remap bench saved to ${REMAP_JSON}" + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Remap Function Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " All outputs in: ${RUN_DIR}/" +echo " calibration/raw_histograms.npy — real topk distribution (per layer)" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench.json — per-config remap/topk/fused/baseline latencies" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/remap_function_bench.sh b/examples/remap_function_bench_topk30.sh similarity index 84% rename from examples/remap_function_bench.sh rename to examples/remap_function_bench_topk30.sh index 7d56d57a..3cb52e25 100755 --- a/examples/remap_function_bench.sh +++ b/examples/remap_function_bench_topk30.sh @@ -43,6 +43,8 @@ # bash remap_function_bench.sh --gpu 0 \ # --model-name Qwen/Qwen3-8B \ # --real-histograms /path/to/calibration/raw_histograms.npy +# # Tight GPU: lower calibration KV cap (default 1048576): +# bash remap_function_bench_topk30.sh --gpu 0 --max-total-tokens 524288 # ============================================================ set -euo pipefail @@ -50,27 +52,29 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=4 +GPU_ID=5 MODEL_NAME="Qwen/Qwen3-1.7B" -TOPK_VAL=2048 +TOPK_VAL=30 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" SAMPLE_STRIDE=1 -SEQ_LEN=65536 +SEQ_LEN=32768 BLOCK_SIZE=16 -BATCH_SIZE=4 +BATCH_SIZE=1 NUM_KV_HEADS=8 DISTRIBUTIONS="normal bucket_uniform" -# Modes 1 (LUT_CDF) and 2 (Quantile) are evaluated only if calibration -# produces lut.npy / quantiles.npy. The shell script detects that below. -MAPPING_MODES="0 1 2 3 6 7 8 9 10 11 13" +# Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their +# mapping happens inside compute_stage1_bin, not apply_transform, so +# split-phase timing isn't meaningful for them. +MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" # Fallback hparam used only if autotune is explicitly skipped. MAPPING_HPARAM=0.5 REPEAT=100 WARMUP=20 # Empty by default — Step 1 will calibrate on the selected model. # Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. -REAL_HISTOGRAMS="" +REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" SKIP_AUTOTUNE=0 # ── Parse arguments ─────────────────────────────────────────── @@ -79,6 +83,7 @@ while [[ $# -gt 0 ]]; do --model-name) MODEL_NAME="$2"; shift 2 ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -98,6 +103,22 @@ while [[ $# -gt 0 ]]; do done export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" + +# Qwen3-1.7B does not use DeepGEMM (no FP8/MoE path). +# Disable its JIT to silence "NVCC Compiler not found ... use NVRTC" on Blackwell. +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +# If DeepGEMM JIT is ever re-enabled, make sure it can find nvcc. +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi # Validate seq_len: need pages/seg > topk_val (3 reserved pages) MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) @@ -126,6 +147,7 @@ echo " KV heads: ${NUM_KV_HEADS}" echo " Distributions: ${DISTRIBUTIONS}" echo " Mapping modes: ${MAPPING_MODES}" echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" echo " GPU: ${GPU_ID}" echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" @@ -149,7 +171,9 @@ else python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" @@ -157,14 +181,8 @@ else echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" fi -# Calibration may have produced lut.npy / quantiles.npy for modes 1 and 2. -CALIB_DIR="$(dirname "${REAL_HIST_PATH}")" -LUT_PATH="" -Q_PATH="" -[ -f "${CALIB_DIR}/lut.npy" ] && LUT_PATH="${CALIB_DIR}/lut.npy" -[ -f "${CALIB_DIR}/quantiles.npy" ] && Q_PATH="${CALIB_DIR}/quantiles.npy" -[ -n "${LUT_PATH}" ] && echo " Calibration LUT: ${LUT_PATH}" -[ -n "${Q_PATH}" ] && echo " Calibration quantile: ${Q_PATH}" +# Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so +# lut.npy / quantiles.npy produced by calibration are no longer consumed. # ── Step 2: Auto-tune hyperparameters by profiled fused-topk latency ── # For every (mode, hparam) combo in the sweep grid, the autotune runs the @@ -179,9 +197,6 @@ if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then else echo "" echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" - AUTOTUNE_EXTRA=() - [ -n "${LUT_PATH}" ] && AUTOTUNE_EXTRA+=(--lut-path "${LUT_PATH}") - [ -n "${Q_PATH}" ] && AUTOTUNE_EXTRA+=(--quantiles-path "${Q_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ --batch-size "${BATCH_SIZE}" \ --num-kv-heads "${NUM_KV_HEADS}" \ @@ -192,7 +207,6 @@ else --warmup "${WARMUP}" \ --repeat "${REPEAT}" \ --collect-stats \ - "${AUTOTUNE_EXTRA[@]}" \ --output-json "${AUTOTUNE_JSON}" \ 2>&1 | tee "${RUN_DIR}/step2_autotune.log" echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" @@ -204,8 +218,7 @@ echo "" echo ">>> Step 3: Timing remap / topk / fused / baseline with autotuned hparams" REMAP_JSON="${RUN_DIR}/remap_bench.json" BENCH_EXTRA=() -[ -n "${LUT_PATH}" ] && BENCH_EXTRA+=(--lut-path "${LUT_PATH}") -[ -n "${Q_PATH}" ] && BENCH_EXTRA+=(--quantiles-path "${Q_PATH}") +[ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --remap-bench \ --batch-sizes "${BATCH_SIZE}" \ diff --git a/examples/run_distribution_analysis.sh b/examples/run_distribution_analysis.sh index 36d4cd4b..25150153 100755 --- a/examples/run_distribution_analysis.sh +++ b/examples/run_distribution_analysis.sh @@ -22,6 +22,7 @@ # --real-histograms /path/to/calibration_dir/raw_histograms.npy # bash run_distribution_analysis.sh --gpu 5 --block-size 16 # bash run_distribution_analysis.sh --watchdog-timeout 0 # disable calibrate watchdog (fork) +# bash run_distribution_analysis.sh --max-total-tokens 1048576 # cap KV / VTX buffers during calibrate # Models (default: 1.7B + 4B). Override with repeated --model-name: # bash run_distribution_analysis.sh --model-name Qwen/Qwen3-1.7B --model-name Qwen/Qwen3-4B # ============================================================ @@ -52,6 +53,7 @@ MODEL_NAMES=( "Qwen/Qwen3-1.7B" "Qwen/Qwen3-4B" ) MODEL_NAMES_USER_SET=0 TOPK_VAL=30 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" RADIX_BITS=8 SAMPLE_STRIDE=1 @@ -76,6 +78,7 @@ while [[ $# -gt 0 ]]; do ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -117,6 +120,7 @@ echo " Block size: ${BLOCK_SIZE} (--page-size in benchmarks)" echo " GPU: ${GPU_ID}" echo " Radix bits: ${RADIX_BITS} ($(( 1 << RADIX_BITS )) bins)" echo " Sample stride: ${SAMPLE_STRIDE}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" if [ "${HAS_WATCHDOG_TIMEOUT}" -eq 1 ]; then echo " Watchdog (cal): ${WATCHDOG_TIMEOUT}s (0 = off, vortex SGLang fork)" else @@ -154,6 +158,7 @@ for MODEL_NAME in "${MODEL_NAMES[@]}"; do --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/run_distribution_analysis_new.sh b/examples/run_distribution_analysis_new.sh index ec726656..38438bde 100755 --- a/examples/run_distribution_analysis_new.sh +++ b/examples/run_distribution_analysis_new.sh @@ -26,6 +26,7 @@ # --model-name Qwen/Qwen3-8B --block-size 32 # bash run_distribution_analysis_new.sh --gpu 5 \ # --real-histograms /path/to/raw_histograms.npy +# bash run_distribution_analysis_new.sh --gpu 5 --max-total-tokens 524288 # ============================================================ set -euo pipefail @@ -37,6 +38,7 @@ GPU_ID=4 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" SEQ_LEN=65536 BLOCK_SIZE=16 @@ -55,6 +57,7 @@ while [[ $# -gt 0 ]]; do --model-name) MODEL_NAME="$2"; shift 2 ;; --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -97,6 +100,7 @@ echo " Batch size: ${BATCH_SIZE}" echo " KV heads: ${NUM_KV_HEADS}" echo " Distributions: ${DISTRIBUTIONS}" echo " Mapping modes: ${MAPPING_MODES}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" echo " GPU: ${GPU_ID}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" echo " Output: ${RUN_DIR}" @@ -116,6 +120,7 @@ else --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/run_topk_benchmark.sh b/examples/run_topk_benchmark.sh index 33c6e40d..f3eabff9 100755 --- a/examples/run_topk_benchmark.sh +++ b/examples/run_topk_benchmark.sh @@ -21,6 +21,7 @@ # bash run_topk_benchmark.sh --gpu 0 # bash run_topk_benchmark.sh --gpu 0 --model-name Qwen/Qwen3-8B \ # --block-size 32 --topk-val 512 +# bash run_topk_benchmark.sh --gpu 0 --max-total-tokens 1048576 # ============================================================ set -euo pipefail @@ -33,6 +34,7 @@ MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 TRIALS=8 MEM=0.7 +MAX_TOTAL_TOKENS=1048576 ALGO="block_sparse_attention" BLOCK_SIZE=16 BATCH_SIZE=4 @@ -50,6 +52,7 @@ while [[ $# -gt 0 ]]; do --topk-val) TOPK_VAL="$2"; shift 2 ;; --trials) TRIALS="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --benchmark) BENCHMARKS="$2"; shift 2 ;; @@ -84,6 +87,7 @@ echo " Seq len: ${SEQ_LEN}" echo " Batch size: ${BATCH_SIZE}" echo " KV heads: ${NUM_KV_HEADS}" echo " Trials: ${TRIALS}" +echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" echo " GPU: ${GPU_ID}" echo " Output: ${RUN_DIR}" echo "============================================================" @@ -101,6 +105,7 @@ else --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${ALGO}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/verify_algo_topk_mapping.sh b/examples/verify_algo_topk_mapping.sh index 711a0f77..f361e594 100644 --- a/examples/verify_algo_topk_mapping.sh +++ b/examples/verify_algo_topk_mapping.sh @@ -26,6 +26,7 @@ BLOCK_SIZE=16 BATCH_SIZE=4 NUM_KV_HEADS=2 SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 REAL_HISTOGRAMS="" SKIP_AUTOTUNE=0 @@ -40,6 +41,7 @@ while [[ $# -gt 0 ]]; do --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; --seq-len) SEQ_LEN="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac @@ -80,12 +82,14 @@ done # ============================================================ if [ -z "${REAL_HISTOGRAMS}" ]; then CALIBRATION_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" + echo ">>> Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" for algo in "${sparse_algos[@]}"; do echo ">>> Calibrating ${MODEL_NAME} for ${algo}..." python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${algo}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CALIBRATION_DIR}" \ diff --git a/examples/verify_algo_topk_mapping_new.sh b/examples/verify_algo_topk_mapping_new.sh index 9116b722..2cdc5265 100644 --- a/examples/verify_algo_topk_mapping_new.sh +++ b/examples/verify_algo_topk_mapping_new.sh @@ -28,6 +28,7 @@ BLOCK_SIZE=16 BATCH_SIZE=4 NUM_KV_HEADS=2 SEQ_LEN=32768 +MAX_TOTAL_TOKENS=1048576 REAL_HISTOGRAMS="" SKIP_AUTOTUNE=0 @@ -42,6 +43,7 @@ while [[ $# -gt 0 ]]; do --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; --seq-len) SEQ_LEN="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; *) echo "Unknown option: $1"; exit 1 ;; esac @@ -63,6 +65,7 @@ TIMESTAMP=$(date +%Y%m%d_%H%M%S) if [ -z "${REAL_HISTOGRAMS}" ]; then echo "============================================================" echo "Step 0: Calibrating ${MODEL_NAME} for real-distribution histograms" + echo " Max total tokens (KV / VTX cap): ${MAX_TOTAL_TOKENS}" echo "============================================================" CAL_DIR="${RESULTS_DIR}/calibration_${TIMESTAMP}" mkdir -p "${CAL_DIR}" @@ -70,6 +73,7 @@ if [ -z "${REAL_HISTOGRAMS}" ]; then --model-name "${MODEL_NAME}" \ --topk-val "${TOPK_VAL}" \ --mem 0.7 \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ --vortex-module-name "${sparse_algos[0]}" \ --page-size "${BLOCK_SIZE}" \ --output-dir "${CAL_DIR}" \ diff --git a/setup.py b/setup.py index 0fc46ad8..c9731815 100644 --- a/setup.py +++ b/setup.py @@ -19,6 +19,7 @@ 'csrc/topk.cu', 'csrc/topk_sglang.cu', 'csrc/topk_sglang_profile.cu', + 'csrc/topk_sglang_ori.cu', ], include_dirs=['csrc'], extra_compile_args={ From fe0b8e2881ba39bb93c36ab043a27af59d078ebc Mon Sep 17 00:00:00 2001 From: UED Date: Wed, 15 Apr 2026 00:50:12 -0400 Subject: [PATCH 21/24] Enhance TopK benchmarking and calibration scripts - Added parameter to for per-head benchmarking. - Introduced function to aggregate per-head configurations. - Updated to include metrics for per-head configurations. - Added disk space check in to ensure sufficient space for model downloads. - Implemented regression guard against saving degenerate histograms in calibration. - Modified example scripts for improved calibration and benchmarking workflows. --- benchmarks/bench_topk.py | 180 ++++++++++++++++++++-- benchmarks/calibrate_topk.py | 57 ++++++- csrc/topk_sglang.cu | 52 +++++-- csrc/utils_sglang.cu | 30 ++-- examples/remap_function_bench_topk2028.sh | 56 +++++-- examples/remap_function_bench_topk30.sh | 30 +++- third_party/sglang | 2 +- vortex_torch/indexer/utils_sglang.py | 2 +- 8 files changed, 353 insertions(+), 56 deletions(-) diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index f0860c94..3653c55a 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -397,8 +397,15 @@ def _resolve_hparam(args, mode: int) -> float: def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, - distribution, modes: List[int]) -> dict: - """Time baseline, fused, and split-phase for each mode at one config.""" + distribution, modes: List[int], + head_label: str = "all") -> dict: + """Time baseline, fused, and split-phase for each mode at one config. + + `head_label` is metadata: ``"all"`` for the aggregated table (default), + or a stringified head index ``"0".."N-1"`` for per-head benches. The + caller is responsible for setting ``args._real_histogram`` to the + head-sliced sub-histogram before invoking this function in per-head mode. + """ real_hist = getattr(args, "_real_histogram", None) if distribution == "real" else None inputs = make_topk_inputs( batch_size=batch_size, @@ -489,6 +496,7 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_val": topk_val, "distribution": distribution, "pages_per_seg": pages_per_seg, + "head": head_label, "baseline_ms": baseline["mean_ms"], "naive_ms": naive_ms, "sglang_ori_ms": sglang_ori_ms, @@ -644,17 +652,27 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, return config +# Stage-2 working-set cap, matches SMEM_INPUT_SIZE in fast_topk_clean_fused +# (32 KB dynamic smem / 2 ping-pong buffers / 4 bytes per int = 4096). +_STAGE2_SMEM_CAP = 4096 + + def _print_remap_table(results: List[dict]) -> None: + # The printed table only carries metrics that participate in the + # fused-kernel cost model. All purely-informational columns + # (thr_bin / sel_thr / abv_bins / pg/bin) were dropped — they're + # still in the JSON for downstream tools, just not in the table. header = ( f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " - f"{'fused_ms':>9s} {'base_ms':>9s} {'thr_bin':>7s} {'thr_size':>8s} " - f"{'sel_thr':>7s} {'abv_bins':>8s} {'pg/bin':>7s}" + f"{'fused_ms':>9s} {'base_ms':>9s} " + f"{'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} {'s2_work':>8s}" ) for cfg in results: banner = ( f"\n[batch={cfg['batch_size']} heads={cfg['num_kv_heads']} " f"seq_len={cfg['seq_len']} topk={cfg['topk_val']} " - f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']}]" + f"dist={cfg['distribution']} pages_per_seg={cfg['pages_per_seg']} " + f"head={cfg.get('head', 'all')}]" ) print(banner) extra_notes = [] @@ -666,6 +684,12 @@ def _print_remap_table(results: List[dict]) -> None: if extra_notes: notes_str = " | " + " | ".join(extra_notes) print(f" Baseline: topk_sglang.cu (CUB two-stage){notes_str}") + print( + f" s1p2_load = thr_size (uncapped global re-reads in Stage-1 pass 2) " + f"eff_thr = min(thr_size, {_STAGE2_SMEM_CAP}) " + f"rounds = stage-2 passes (1..4) " + f"s2_work = rounds * eff_thr" + ) print(header) print("-" * len(header)) base_ms = cfg["baseline_ms"] @@ -681,6 +705,11 @@ def _print_remap_table(results: List[dict]) -> None: def _fmt(v): return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" fused_str = _fmt(row.get("fused_ms")) + thr_size = row.get("threshold_bin_size_mean", 0.0) + rounds = row.get("refine_rounds_mean", 0.0) + eff_thr = min(thr_size, float(_STAGE2_SMEM_CAP)) + s2_work = rounds * eff_thr + s1p2_load = thr_size # alias: same number, named for the cost-model role print( f"{label:<14s} " f"{_fmt(row['remap_ms'])} " @@ -688,14 +717,77 @@ def _fmt(v): f"{_fmt(row['split_total_ms'])} " f"{fused_str} " f"{base_ms:9.4f} " - f"{row['threshold_bin_mean']:7.1f} " - f"{row['threshold_bin_size_mean']:8.1f} " - f"{row['selected_from_thr_mean']:7.1f} " - f"{row.get('above_bins_mean', 0.0):8.1f} " - f"{row.get('pages_per_above_bin_mean', 0.0):7.1f}" + f"{s1p2_load:9.0f} " + f"{eff_thr:7.0f} " + f"{rounds:6.2f} " + f"{s2_work:8.0f}" ) +def _combine_per_head_cfgs(per_head_cfgs: List[dict]) -> dict: + """Combine a list of per-head cfg dicts (same shape, head='0','1',...) + into a single aggregated cfg tagged head='all', by averaging every + numeric field. This is used when --per-head-bench is on so the + aggregated row reflects the realistic per-head behaviour rather than + a separate kernel launch on an averaged histogram. + + Assumes every cfg has the same `modes` list in the same order — which + holds because all per-head sub-runs use identical (batch, heads, seq, + topk, page_size, reserved, mapping_modes) parameters and therefore + take the same code paths through `_remap_bench_one_config`. + """ + assert per_head_cfgs, "_combine_per_head_cfgs called with empty list" + base = per_head_cfgs[0] + n_modes = len(base["modes"]) + # Sanity: same shape. + for c in per_head_cfgs[1:]: + assert len(c["modes"]) == n_modes, ( + f"per-head cfgs disagree on mode count: {n_modes} vs {len(c['modes'])}" + ) + + def _mean_or_none(vals): + vs = [v for v in vals if v is not None] + return (sum(vs) / len(vs)) if vs else None + + combined: Dict = { + "batch_size": base["batch_size"], + "num_kv_heads": base["num_kv_heads"], + "seq_len": base["seq_len"], + "topk_val": base["topk_val"], + "distribution": base["distribution"], + "pages_per_seg": base["pages_per_seg"], + "head": "all", + "baseline_ms": _mean_or_none([c.get("baseline_ms") for c in per_head_cfgs]), + "naive_ms": _mean_or_none([c.get("naive_ms") for c in per_head_cfgs]), + "sglang_ori_ms": _mean_or_none([c.get("sglang_ori_ms") for c in per_head_cfgs]), + "modes": [], + } + + # Numeric fields per mode row that we average; non-numeric fields (mode, + # mode_name, power) are copied from the first cfg since they're identical + # across heads by construction. + NUMERIC_KEYS = ( + "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms", + "threshold_bin_mean", "threshold_bin_max", + "num_above_mean", + "threshold_bin_size_mean", "threshold_bin_size_max", + "selected_from_thr_mean", "selected_from_thr_max", + "refine_rounds_mean", + "above_bins_mean", "pages_per_above_bin_mean", + ) + for mi in range(n_modes): + sample = base["modes"][mi] + merged = { + "mode": sample["mode"], + "mode_name": sample["mode_name"], + "power": sample["power"], + } + for key in NUMERIC_KEYS: + merged[key] = _mean_or_none([c["modes"][mi].get(key) for c in per_head_cfgs]) + combined["modes"].append(merged) + return combined + + def _run_remap_bench(args) -> None: modes = [int(m) for m in args.mapping_modes] # Mode 0 is emitted as the "None" row from _remap_bench_one_config @@ -710,14 +802,68 @@ def _run_remap_bench(args) -> None: print(f"[remap-bench] 'real' distribution enabled " f"(histogram total count = {int(args._real_histogram.sum())})") + if getattr(args, "per_head_bench", False): + if getattr(args, "_real_histograms_raw", None) is None: + raise SystemExit( + "[bench-remap] --per-head-bench requires --real-histograms with a 2D raw file." + ) + if not args.num_kv_heads or any(h <= 0 for h in args.num_kv_heads): + raise SystemExit("[bench-remap] --per-head-bench requires --num-kv-heads > 0.") + # When the user passes multiple --num-kv-heads values we slice by the + # first one (the others are degenerate for per-head reporting since + # the histogram file has a fixed head count). + per_head_count = int(args.num_kv_heads[0]) + results = [] + # When --per-head-bench is on, each "real"-distribution aggregate is + # built by averaging the 8 per-head measurements (NOT by running an + # extra kernel on an averaged histogram). This grouping keeps the + # per-head cfgs that should fold into each (bs, heads, seq, topk) + # aggregate point. + per_head_groups: dict = {} + + # ---- Per-head tables (printed first) ---- + if getattr(args, "per_head_bench", False): + raw = args._real_histograms_raw + saved_agg = args._real_histogram + try: + for h in range(per_head_count): + # Slice rows belonging to head `h`. Rows are interleaved as + # row_idx % num_kv_heads = head_idx, so this strided slice + # collects all (call, batch, h) triples across the file. + args._real_histogram = raw[h::per_head_count].sum(axis=0) + for bs in args.batch_sizes: + for heads in args.num_kv_heads: + for seq_len in args.seq_lens: + for topk_val in args.topk_vals: + cfg = _remap_bench_one_config( + args, bs, heads, seq_len, topk_val, "real", modes, + head_label=str(h), + ) + results.append(cfg) + per_head_groups.setdefault( + (bs, heads, seq_len, topk_val), [] + ).append(cfg) + finally: + args._real_histogram = saved_agg + + # ---- Aggregated tables (printed last) ---- for bs in args.batch_sizes: for heads in args.num_kv_heads: for seq_len in args.seq_lens: for topk_val in args.topk_vals: for dist in distributions: + if dist == "real" and getattr(args, "per_head_bench", False): + cfgs = per_head_groups.get((bs, heads, seq_len, topk_val), []) + if cfgs: + # Combine the per-head cfgs into a single + # aggregated row — no extra kernel launch. + cfg = _combine_per_head_cfgs(cfgs) + results.append(cfg) + continue cfg = _remap_bench_one_config( args, bs, heads, seq_len, topk_val, dist, modes, + head_label="all", ) results.append(cfg) @@ -838,6 +984,13 @@ def main(): p.add_argument("--output-json", type=str, default=None) p.add_argument("--remap-bench", action="store_true", help="Run the split-phase remap/topk/fused/baseline benchmark.") + p.add_argument("--per-head-bench", action="store_true", + help="In addition to the aggregated 'real'-distribution table, also " + "run the remap-bench once per KV head: slice the calibrated " + "histogram into one sub-histogram per head (using " + "row_idx %% num_kv_heads = head_idx), bench each, and print one " + "table per head followed by the aggregated table. Requires " + "--real-histograms (with a 2D raw file) and --num-kv-heads.") args = p.parse_args() args._autotune_hparams = {} @@ -848,9 +1001,14 @@ def main(): print(f" mode {m:>2d} -> {v}") args._real_histogram = None + args._real_histograms_raw = None if args.real_histograms: - raw = np.load(args.real_histograms) + # mmap_mode='r' keeps the (potentially 20+ GB) raw file off-heap; we + # only materialise per-head sums when --per-head-bench is set. + raw = np.load(args.real_histograms, mmap_mode='r') args._real_histogram = raw.sum(axis=0) if raw.ndim > 1 else raw + if raw.ndim > 1: + args._real_histograms_raw = raw print(f"[real] loaded calibrated histogram from {args.real_histograms} " f"(shape={raw.shape} → [256] aggregate)") diff --git a/benchmarks/calibrate_topk.py b/benchmarks/calibrate_topk.py index 4914133f..f3343aaa 100644 --- a/benchmarks/calibrate_topk.py +++ b/benchmarks/calibrate_topk.py @@ -17,6 +17,7 @@ import argparse import json import os +import shutil import sys import numpy as np @@ -46,9 +47,17 @@ def main(): "Block-sparse profiling uses a small bytes/token estimate, so the auto " "budget can be huge on large GPUs; VTXGraphAttnBackend then allocates " "dense bf16 sparse_prefill K/V buffers proportional to this cap (~4 KiB per " - "token per buffer). For offline calibration, a few hundred K1M tokens " + "token per buffer). For offline calibration, a few hundred K to 1M tokens " "is usually enough.", ) + parser.add_argument( + "--min-free-disk-gb", + type=float, + default=20.0, + help="Abort if the filesystem for --output-dir (and HF cache, typically the same) " + "has less than this many GiB free. First-time model downloads need many GiB. " + "Set to 0 to disable.", + ) parser.add_argument("--kv-cache-dtype", type=str, default="auto") parser.add_argument("--topk-type", type=str, default="sglang") parser.add_argument("--num-prompts", type=int, default=16, @@ -65,6 +74,30 @@ def main(): ) args = parser.parse_args() + # Classic HTTP downloads avoid XET chunk reconstruction ("Background writer channel + # closed") that often surfaces when the disk is full or nearly full. + if "HF_HUB_DISABLE_XET" not in os.environ: + os.environ["HF_HUB_DISABLE_XET"] = "1" + + if args.min_free_disk_gb > 0: + check_path = os.path.abspath(args.output_dir) + while check_path and not os.path.isdir(check_path): + parent = os.path.dirname(check_path) + if parent == check_path: + check_path = os.getcwd() + break + check_path = parent + usage = shutil.disk_usage(check_path) + free_gb = usage.free / (1024.0**3) + if free_gb < args.min_free_disk_gb: + raise SystemExit( + f"[calibrate] ERROR: Only {free_gb:.1f} GiB free on filesystem containing " + f"{args.output_dir!r} (checked from {check_path!r}). " + f"Need at least ~{args.min_free_disk_gb} GiB for Hugging Face weights, hub cache, " + f"and logs. Free disk space or point HF_HOME at a larger disk. " + f"To skip this check: --min-free-disk-gb 0" + ) + # Lazy imports to avoid slow startup when just checking --help import sglang as sgl import torch @@ -135,6 +168,28 @@ def main(): all_hists = torch.cat(histograms, dim=0).numpy() # [total_samples, 256] print(f"[calibrate] Total histogram samples: {all_hists.shape[0]}") + # Regression guard: refuse to save a collapsed histogram. A healthy + # calibration touches tens to hundreds of bins; if almost everything lands + # in a single bin, the scoring pipeline silently produced zero scores + # (see the Sgl_Decode_Plan_Workload_Kernel `w > topk_val` bug fixed in + # csrc/utils_sglang.cu). Saving 20+ GB of all-zeros wastes disk and poisons + # downstream benches, so fail loudly here. + _pooled = all_hists.sum(axis=0).astype(np.float64) + _total = float(_pooled.sum()) + if _total > 0: + _top_frac = float(_pooled.max()) / _total + _nz_bins = int((_pooled > 0).sum()) + if _top_frac > 0.95 or _nz_bins < 5: + llm.shutdown() + raise SystemExit( + f"[calibrate] ERROR: degenerate histogram — top bin holds " + f"{_top_frac:.2%} of mass, only {_nz_bins}/256 bins nonzero. " + f"The scoring pipeline is likely not running (check " + f"winfo_num_workloads in plan_decode, or `w > topk_val` in " + f"Sgl_Decode_Plan_Workload_Kernel). Refusing to save to avoid " + f"writing a useless multi-GB file." + ) + # --- Generate LUT (mode 1) --- # Aggregate histogram across all samples avg_histogram = all_hists.mean(axis=0) diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index fa9c8250..73366dfa 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -50,6 +50,15 @@ constexpr size_t kSmem = 48 * 1024; // bytes constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32KB (bytes) #endif +// Fused-kernel dynamic smem ceiling. The fused kernel uses `kSmem` bytes for +// f_input_idx (2 × SMEM_INPUT_SIZE ints) AND an extra `max_num_pages` bytes +// for s_bins (one uint8_t per page). Ceiling of 96 KB covers max_num_pages up +// to 65536 and fits the opt-in dynamic-smem limits on every target in +// setup.py (sm_86 ≥99KB, sm_89 100KB, sm_90 228KB, sm_100a/120 ≥100KB). +// Only `topk_output_sglang_fused` uses this ceiling; the other kernels keep +// kSmem as their dynamic-smem budget. +constexpr size_t kFusedSmemMax = 96 * 1024; + struct FastTopKParams { const float* __restrict__ input; // [B, input_stride] const int32_t* __restrict__ row_starts; // [B] @@ -670,16 +679,23 @@ __device__ void fast_topk_clean_fused( // Per-element Stage-1 bin cache. Pass 1 of Stage 1 writes one byte per // element; pass 2 reads it back so each element only pays a single - // apply_transform + global score read instead of two. Sized to the - // maximum `pages_per_seg` the bench drivers use (topk=2048 config has - // seq_len=32768 / page_size=8 = 4096 pages per segment; topk=30 has - // 2048). Shrinking from 8192 to 4096 freed 4 KB of static SMEM per - // block, which lifts occupancy from 5 → 6 blocks/SM on B200. - constexpr int kFusedMaxLen = 4096; - __shared__ uint8_t s_bins[kFusedMaxLen]; + // apply_transform + global score read instead of two. + // + // s_bins lives in DYNAMIC shared memory, placed immediately after the + // f_input_idx[2][SMEM_INPUT_SIZE] 2D array in the same extern __shared__ + // region. The host launch reserves `kSmem + max_num_pages` dynamic bytes + // (see `topk_output_sglang_fused`) so every block has `max_num_pages` + // bytes available past f_input_idx's 32 KB span. Per-block `length` + // (from dense_kv_indptr) is ≤ max_num_pages, so indexing stays in bounds. + // + // This layout keeps smem usage at kSmem + 4 KB for the existing + // pages_per_seg ≤ 4096 regimes (identical to the old 32 KB dynamic + + // 4 KB static) and only grows when the caller asks for a larger + // pages_per_seg — no occupancy regression on small configs. auto& f_histogram = f_histogram_buf[0]; extern __shared__ int f_input_idx[][SMEM_INPUT_SIZE]; + uint8_t* const s_bins = reinterpret_cast(&f_input_idx[2][0]); const int tx = threadIdx.x; @@ -1192,10 +1208,20 @@ void topk_output_sglang_fused( "topk_output_sglang_fused: topk_val (", topk_val, ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - // Caller contract: max_num_pages must be <= 4096, the static SMEM - // `s_bins` cache size inside the templated fused kernel. The bench - // drivers stay within this bound; no runtime check is emitted in - // the hot path. + // Dynamic-smem layout for the fused kernel: + // [ f_input_idx (2 × SMEM_INPUT_SIZE × sizeof(int) = kSmem bytes) + // s_bins (bins_bytes = align_up(max_num_pages, 16)) ] + // The per-launch smem request equals the total of both. It must fit + // under kFusedSmemMax, which setup_kernel_smem_once opted this kernel + // into via cudaFuncSetAttribute(MaxDynamicSharedMemorySize, ...). + const size_t bins_bytes = (static_cast(max_num_pages) + size_t(15)) & ~size_t(15); + const size_t smem_bytes = kSmem + bins_bytes; + TORCH_CHECK(smem_bytes <= kFusedSmemMax, + "topk_output_sglang_fused: max_num_pages (", max_num_pages, + ") exceeds the fused kernel's dynamic smem ceiling. " + "Requested smem=", smem_bytes, " bytes, ceiling=", kFusedSmemMax, + " bytes. Raise kFusedSmemMax (and verify GPU opt-in limits) or " + "reduce pages_per_seg."); // The `mapping_lut` / `mapping_quantiles` optional tensors are // retained in the pybind signature for API backward compatibility @@ -1226,8 +1252,8 @@ void topk_output_sglang_fused( // time per-call cost, negligible relative to the kernel runtime. #define VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MODE_VAL) \ do { \ - setup_kernel_smem_once, kSmem>(); \ - TopKOutput_Fused_Kernel<<>>( \ + setup_kernel_smem_once, kFusedSmemMax>(); \ + TopKOutput_Fused_Kernel<<>>( \ PTR_EXPR, \ dense_kv_indptr.data_ptr(), \ sparse_kv_indptr.data_ptr(), \ diff --git a/csrc/utils_sglang.cu b/csrc/utils_sglang.cu index 1420e9ec..a7ddf42f 100644 --- a/csrc/utils_sglang.cu +++ b/csrc/utils_sglang.cu @@ -82,16 +82,20 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // See note in Sgl_Decode_Plan_Workload_Kernel: we used to skip slots + // where w ≤ topk_val, but downstream (GeMV / topK / histogram) has no + // matching skip, so it read uninitialised scores and silently + // produced all-zero results. Emit workloads for every slot with w > 0. + page_count[i] = (w > 0) ? w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workload = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; @@ -218,16 +222,22 @@ const int page_reserved_eos #pragma unroll for (int i = 0; i < ITEM_PER_THREAD; ++i){ - int16_t w = ((tx_offset + i) < eff_batch_size) ? - (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] + int16_t w = ((tx_offset + i) < eff_batch_size) ? + (dense_kv_indptr[tx_offset+i+1] - dense_kv_indptr[tx_offset+i] - page_reserved_bos - page_reserved_eos): 0; - - page_count[i] = (w > topk_val) ? w : 0; + + // Previously: (w > topk_val) ? w : 0, which skipped scoring on slots + // where the dense page count is already ≤ topk_val. Downstream (GeMV, + // topK, histogram profiling) does NOT have a matching skip, so it + // would read uninitialised scores and silently return garbage (all + // zero). Emit workloads for every slot with w > 0 so scoring always + // runs; when w ≤ topk_val the topK degenerates to "select all w". + page_count[i] = (w > 0) ? w : 0; chunked_page_count_prefix_sum[i + 1] = int((page_count[i] + max_chunk_size - 1) / max_chunk_size); } BlockScanInt(temp.scan_int).InclusiveSum(chunked_page_count_prefix_sum, chunked_page_count_prefix_sum); - + if (tx == 1023){ *winfo_num_workloads = chunked_page_count_prefix_sum[ITEM_PER_THREAD]; *winfo_chunk_size = max_chunk_size; diff --git a/examples/remap_function_bench_topk2028.sh b/examples/remap_function_bench_topk2028.sh index 6d95b59d..26c529c4 100755 --- a/examples/remap_function_bench_topk2028.sh +++ b/examples/remap_function_bench_topk2028.sh @@ -50,33 +50,40 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 -MODEL_NAME="Qwen/Qwen3-8B" +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 # Cap KV / VTX sparse prefill buffer sizing during Step 1 (see calibrate_topk.py --help). MAX_TOTAL_TOKENS=64768 +# Min free GiB on the output-dir filesystem before Step 1 (HF weights + cache + logs). +MIN_FREE_DISK_GB=22 ALGO="block_sparse_attention" SAMPLE_STRIDE=1 SEQ_LEN=32768 -BLOCK_SIZE=8 +BLOCK_SIZE=1 BATCH_SIZE=4 NUM_KV_HEADS=8 DISTRIBUTIONS="normal bucket_uniform" # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their # mapping happens inside compute_stage1_bin, not apply_transform, so # split-phase timing isn't meaningful for them. -MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19" # Fallback hparam used only if autotune is explicitly skipped. MAPPING_HPARAM=0.5 REPEAT=100 WARMUP=20 # Empty by default — Step 1 will calibrate on the selected model. # Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. -# REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" -#REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms_qwen3-4B.npy" -REAL_HISTOGRAMS="" +# REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms.npy" +#REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy" SKIP_AUTOTUNE=0 +# Optional: pre-built autotune JSON to bypass Step 2 entirely. When set, +# Step 2 is skipped and Step 3 reads its per-mode hparams from this file +# instead. Useful for verification runs where we want to pin the exact +# (mode, hparam) pairs without re-running the latency sweep. +PINNED_AUTOTUNE_JSON="" # ── Parse arguments ─────────────────────────────────────────── while [[ $# -gt 0 ]]; do @@ -85,6 +92,7 @@ while [[ $# -gt 0 ]]; do --topk-val) TOPK_VAL="$2"; shift 2 ;; --mem) MEM="$2"; shift 2 ;; --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; --gpu) GPU_ID="$2"; shift 2 ;; --algo) ALGO="$2"; shift 2 ;; --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; @@ -99,6 +107,7 @@ while [[ $# -gt 0 ]]; do --repeat) REPEAT="$2"; shift 2 ;; --warmup) WARMUP="$2"; shift 2 ;; --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; *) echo "Unknown option: $1"; exit 1 ;; esac done @@ -136,6 +145,18 @@ MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" +# Calibration artifacts live on /var/tmp (large disk), keyed by model. +# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-8B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + echo "============================================================" echo "Remap Function Benchmark" echo " Model: ${MODEL_NAME}" @@ -149,6 +170,7 @@ echo " Distributions: ${DISTRIBUTIONS}" echo " Mapping modes: ${MAPPING_MODES}" echo " Fallback hparam: ${MAPPING_HPARAM} (used only when --skip-autotune)" echo " Max total tokens: ${MAX_TOTAL_TOKENS} (calibration KV / VTX buffer cap)" +echo " Min free disk: ${MIN_FREE_DISK_GB} GiB (Step 1 preflight; 0 = skip)" echo " GPU: ${GPU_ID}" echo " Sample stride: ${SAMPLE_STRIDE}" echo " Real histograms: ${REAL_HISTOGRAMS:-}" @@ -167,7 +189,7 @@ if [ -n "${REAL_HISTOGRAMS}" ]; then else echo "" echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" - CALIBRATION_DIR="${RUN_DIR}/calibration" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ @@ -175,11 +197,15 @@ else --page-size "${BLOCK_SIZE}" \ --mem "${MEM}" \ --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ --vortex-module-name "${ALGO}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" - echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + # Promote raw_histograms.npy to the shared per-model cache path. + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" fi # Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so @@ -193,8 +219,13 @@ fi AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then echo "" - echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" - AUTOTUNE_ARGS="" + if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then + echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" + AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" + else + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" + fi else echo "" echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" @@ -222,6 +253,7 @@ BENCH_EXTRA=() [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --remap-bench \ + --per-head-bench \ --batch-sizes "${BATCH_SIZE}" \ --num-kv-heads "${NUM_KV_HEADS}" \ --seq-lens "${SEQ_LEN}" \ diff --git a/examples/remap_function_bench_topk30.sh b/examples/remap_function_bench_topk30.sh index 3cb52e25..3843906c 100755 --- a/examples/remap_function_bench_topk30.sh +++ b/examples/remap_function_bench_topk30.sh @@ -52,7 +52,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=5 +GPU_ID=1 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=30 MEM=0.7 @@ -61,20 +61,20 @@ ALGO="block_sparse_attention" SAMPLE_STRIDE=1 SEQ_LEN=32768 BLOCK_SIZE=16 -BATCH_SIZE=1 +BATCH_SIZE=4 NUM_KV_HEADS=8 DISTRIBUTIONS="normal bucket_uniform" # Modes 1 (LUT_CDF) and 2 (Quantile) are no longer benchmarked — their # mapping happens inside compute_stage1_bin, not apply_transform, so # split-phase timing isn't meaningful for them. -MAPPING_MODES="0 3 6 7 8 9 10 11 13 15 16 17 18 19 20" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19 20" # Fallback hparam used only if autotune is explicitly skipped. MAPPING_HPARAM=0.5 REPEAT=100 WARMUP=20 # Empty by default — Step 1 will calibrate on the selected model. # Pass --real-histograms /path/to/raw_histograms.npy to skip calibration. -REAL_HISTOGRAMS="/home/zhuominc/xinrui_projects/vortex_torch/examples/calibration/raw_histograms.npy" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-4B.npy" SKIP_AUTOTUNE=0 # ── Parse arguments ─────────────────────────────────────────── @@ -135,6 +135,18 @@ MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" RUN_DIR="${RESULTS_DIR}/remap_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" +# Calibration artifacts live on /var/tmp (large disk), keyed by model. +# Example: /var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +# If no explicit --real-histograms and a cached file exists, reuse it. +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + echo "============================================================" echo "Remap Function Benchmark" echo " Model: ${MODEL_NAME}" @@ -166,7 +178,7 @@ if [ -n "${REAL_HISTOGRAMS}" ]; then else echo "" echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" - CALIBRATION_DIR="${RUN_DIR}/calibration" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" mkdir -p "${CALIBRATION_DIR}" python "${BENCH_DIR}/calibrate_topk.py" \ --model-name "${MODEL_NAME}" \ @@ -177,8 +189,11 @@ else --vortex-module-name "${ALGO}" \ --output-dir "${CALIBRATION_DIR}" \ 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - REAL_HIST_PATH="${CALIBRATION_DIR}/raw_histograms.npy" - echo ">>> Step 1: Done. Calibration saved to ${CALIBRATION_DIR}" + # Promote raw_histograms.npy to the shared per-model cache path. + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" + echo ">>> Step 1: Staging dir (lut/quantiles/logs): ${CALIBRATION_DIR}" fi # Modes 1 (LUT_CDF) and 2 (Quantile) are dropped from the comparison, so @@ -221,6 +236,7 @@ BENCH_EXTRA=() [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ --remap-bench \ + --per-head-bench \ --batch-sizes "${BATCH_SIZE}" \ --num-kv-heads "${NUM_KV_HEADS}" \ --seq-lens "${SEQ_LEN}" \ diff --git a/third_party/sglang b/third_party/sglang index 47faead5..b7825d08 160000 --- a/third_party/sglang +++ b/third_party/sglang @@ -1 +1 @@ -Subproject commit 47faead5448b14681ac57fc9a3c6311654fc2b17 +Subproject commit b7825d08399fccdf1f29a5380d6601fcef59aca1 diff --git a/vortex_torch/indexer/utils_sglang.py b/vortex_torch/indexer/utils_sglang.py index 74b8cfe6..343207fc 100644 --- a/vortex_torch/indexer/utils_sglang.py +++ b/vortex_torch/indexer/utils_sglang.py @@ -40,7 +40,7 @@ def plan_decode( ctx.max_chunk_size, ctx.min_chunk_size ) - + ctx.set_batch_size(cached_seq_lens.shape[0]) From 13cb8a015a63be6a5cd8060fb6c889915de8b879 Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 19 Apr 2026 19:06:04 -0400 Subject: [PATCH 22/24] Add parallel TopK kernel and profiling enhancements - Introduced for a multi-CTA split+merge variant of the TopK kernel to improve GPU utilization. - Updated to include the new source file for the parallel kernel. - Enhanced to support benchmarking of the parallel kernel, including automatic split determination. - Added a new profiling script for comparing performance between the parallel and fused TopK kernels. - Updated example scripts to facilitate ablation studies on remap functions and kernel performance across different configurations. --- benchmarks/bench_topk.py | 74 +- benchmarks/profile_parallel_vs_fused.py | 99 +++ csrc/register.cc | 11 + csrc/register.h | 18 + csrc/topk_sglang_parallel.cu | 811 ++++++++++++++++++ .../ablation_remap_function_block_size.sh | 279 ++++++ examples/ablation_remap_function_model.sh | 262 ++++++ .../ablation_remap_function_topk_benchmark.sh | 277 ++++++ examples/ablation_remap_function_topk_val.sh | 255 ++++++ examples/analyze_ablation_remap.py | 416 +++++++++ examples/profile_in_docker.sh | 181 ++++ examples/profile_parallel_vs_fused_ncu.sh | 277 ++++++ examples/profile_parallel_vs_fused_nsys.sh | 211 +++++ .../remap_function_bench_topk_parallel.sh | 245 ++++++ examples/verify_algo.py | 29 +- setup.py | 1 + 16 files changed, 3443 insertions(+), 3 deletions(-) create mode 100644 benchmarks/profile_parallel_vs_fused.py create mode 100644 csrc/topk_sglang_parallel.cu create mode 100644 examples/ablation_remap_function_block_size.sh create mode 100644 examples/ablation_remap_function_model.sh create mode 100644 examples/ablation_remap_function_topk_benchmark.sh create mode 100644 examples/ablation_remap_function_topk_val.sh create mode 100644 examples/analyze_ablation_remap.py create mode 100755 examples/profile_in_docker.sh create mode 100755 examples/profile_parallel_vs_fused_ncu.sh create mode 100755 examples/profile_parallel_vs_fused_nsys.sh create mode 100755 examples/remap_function_bench_topk_parallel.sh diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 3653c55a..a717c7b5 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -27,6 +27,7 @@ topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) topk_output_sglang_fused, # fused remap + 2-stage radix topk topk_output_sglang_ori, # original SGLang reference kernel + topk_output_sglang_parallel, # multi-CTA split+merge variant of the fused kernel topk_remap_only, # standalone value-space remap topk_profile_histogram, topk_profile_counters, @@ -76,6 +77,32 @@ _AUTOTUNE_TIE_TOLERANCE_MS = 0.0002 # ≈ CUDA event noise floor at this kernel size +def _auto_num_splits(eff_batch_size: int, pages_per_seg: int, topk_val: int) -> int: + """Pick num_splits to balance Phase-1 and Phase-2 work on the parallel + kernel. + + Phase-1 per CTA does O(pages/splits) work and runs eff_batch_size*splits + CTAs in parallel; Phase-2 runs eff_batch_size CTAs each doing + O(splits*topk) work on the merged candidate list. Assuming both phases + hit SM saturation, total ≈ (pages/splits + splits*topk)/throughput, + minimized at splits = sqrt(pages/topk). Cap at the SM-budget for + eff_batch_size and the max_safe value (pages_per_seg // topk_val, past + which Phase 1 partitions are smaller than topk_val and gain nothing). + + Returns 1 when splitting cannot help. + """ + max_safe = max(1, pages_per_seg // max(1, topk_val)) + if max_safe <= 1 or eff_batch_size <= 0: + return 1 + try: + sm = torch.cuda.get_device_properties(0).multi_processor_count + except Exception: + sm = 132 + balanced = max(1, int(round((pages_per_seg / max(1, topk_val)) ** 0.5))) + sm_budget = max(1, sm // max(1, eff_batch_size)) + return max(1, min(balanced, sm_budget, max_safe)) + + def _load_autotune_hparams(path: str) -> Dict[int, float]: """Load per-mode best hyperparameters from an autotune_results.json. @@ -514,6 +541,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": naive_ms, "split_total_ms": None, "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, "threshold_bin_mean": 0.0, "threshold_bin_max": 0.0, "num_above_mean": 0.0, @@ -541,6 +570,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": baseline["mean_ms"], "split_total_ms": None, "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, **none_stats, }) @@ -556,6 +587,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": sglang_ori_ms, "split_total_ms": None, "fused_ms": None, + "parallel_ms": None, + "parallel_splits": None, "threshold_bin_mean": 0.0, "threshold_bin_max": 0.0, "num_above_mean": 0.0, @@ -595,6 +628,32 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) + # Multi-CTA split+merge variant of the fused kernel. num_splits <= 1 + # delegates to the single-CTA fused path, so this is only a + # meaningful extra data point when we can actually split. + parallel_ms = None + parallel_splits_used = None + if getattr(args, "bench_parallel", False): + splits = getattr(args, "num_splits", -1) + if splits is None or splits < 1: + splits = _auto_num_splits(eff_bs, pages_per_seg, topk_val) + parallel_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + splits, + mode, power, lut_t, q_t, + ) + inputs["sparse_kv_indices"].zero_() + parallel = bench_kernel( + topk_output_sglang_parallel, parallel_args, args.warmup, args.repeat + ) + parallel_ms = parallel["mean_ms"] + parallel_splits_used = splits + # Split-phase timing is only meaningful for arithmetic modes. # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside # compute_stage1_bin, which topk_remap_only cannot reproduce, so we @@ -645,6 +704,8 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": topk_after_remap_ms, "split_total_ms": split_total_ms, "fused_ms": fused["mean_ms"], + "parallel_ms": parallel_ms, + "parallel_splits": parallel_splits_used, **stats, } config["modes"].append(row) @@ -664,7 +725,7 @@ def _print_remap_table(results: List[dict]) -> None: # still in the JSON for downstream tools, just not in the table. header = ( f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " - f"{'fused_ms':>9s} {'base_ms':>9s} " + f"{'fused_ms':>9s} {'par_ms':>9s} {'splits':>6s} {'base_ms':>9s} " f"{'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} {'s2_work':>8s}" ) for cfg in results: @@ -705,6 +766,9 @@ def _print_remap_table(results: List[dict]) -> None: def _fmt(v): return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" fused_str = _fmt(row.get("fused_ms")) + par_str = _fmt(row.get("parallel_ms")) + splits = row.get("parallel_splits") + splits_str = f"{splits:>6d}" if splits is not None else f"{'N/A':>6s}" thr_size = row.get("threshold_bin_size_mean", 0.0) rounds = row.get("refine_rounds_mean", 0.0) eff_thr = min(thr_size, float(_STAGE2_SMEM_CAP)) @@ -716,6 +780,8 @@ def _fmt(v): f"{_fmt(row['topk_after_remap_ms'])} " f"{_fmt(row['split_total_ms'])} " f"{fused_str} " + f"{par_str} " + f"{splits_str} " f"{base_ms:9.4f} " f"{s1p2_load:9.0f} " f"{eff_thr:7.0f} " @@ -768,6 +834,7 @@ def _mean_or_none(vals): # across heads by construction. NUMERIC_KEYS = ( "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms", + "parallel_ms", "threshold_bin_mean", "threshold_bin_max", "num_above_mean", "threshold_bin_size_mean", "threshold_bin_size_max", @@ -984,6 +1051,11 @@ def main(): p.add_argument("--output-json", type=str, default=None) p.add_argument("--remap-bench", action="store_true", help="Run the split-phase remap/topk/fused/baseline benchmark.") + p.add_argument("--bench-parallel", action="store_true", + help="Also time topk_output_sglang_parallel (multi-CTA split+merge).") + p.add_argument("--num-splits", type=int, default=-1, + help="Partitions per batch for the parallel kernel. -1 = auto " + "(sm_count / eff_batch_size, clamped to pages_per_seg/topk_val).") p.add_argument("--per-head-bench", action="store_true", help="In addition to the aggregated 'real'-distribution table, also " "run the remap-bench once per KV head: slice the calibrated " diff --git a/benchmarks/profile_parallel_vs_fused.py b/benchmarks/profile_parallel_vs_fused.py new file mode 100644 index 00000000..ecfd8723 --- /dev/null +++ b/benchmarks/profile_parallel_vs_fused.py @@ -0,0 +1,99 @@ +""" +Driver for Nsight Compute profiling of the parallel vs fused TopK +kernels. Designed to be launched under `ncu` with --launch-skip and +--launch-count to isolate a specific kernel launch from warmup. + +The script does exactly: + args.warmup matching-kernel launches (skipped by ncu --launch-skip) + args.iters matching-kernel launches (captured by ncu --launch-count) + +Pair --launch-skip/--launch-count with --kernel-name so unrelated +launches (torch initializers, cublas, etc.) don't pollute the counts. +""" +import argparse +import torch +from vortex_torch_C import ( + topk_output_sglang_fused, + topk_output_sglang_parallel, +) + + +def make_inputs(eff_bs: int, pages: int, topk: int): + reserved = 0 + dense_indptr = torch.arange( + 0, (eff_bs + 1) * pages, pages, dtype=torch.int32, device="cuda" + ) + sparse_indptr = torch.arange( + 0, (eff_bs + 1) * topk, topk, dtype=torch.int32, device="cuda" + ) + dense_indices = torch.arange(eff_bs * pages, dtype=torch.int32, device="cuda") + torch.manual_seed(0) + x = torch.randn(eff_bs * pages, 1, 1, dtype=torch.bfloat16, device="cuda") + out = torch.zeros(eff_bs * topk, dtype=torch.int32, device="cuda") + return x, dense_indptr, sparse_indptr, dense_indices, out, reserved + + +def main(): + p = argparse.ArgumentParser() + p.add_argument( + "--config", + choices=["A", "B"], + required=True, + help="A: topk=2048 pages=32K ; B: topk=30 pages=2K", + ) + p.add_argument("--eff-bs", type=int, default=1) + p.add_argument( + "--mode", type=int, choices=[15, 16], required=True, + help="15=MAPPING_SHIFT_POW2, 16=MAPPING_SHIFT_POW3", + ) + p.add_argument( + "--power", type=float, default=0.5, + help="Pivot (p) for the shift_pow transforms. 0.5 matches the " + "autotune default for Qwen3-1.7B softmax scores.", + ) + p.add_argument("--num-splits", type=int, default=4) + p.add_argument("--kernel", choices=["fused", "parallel"], required=True) + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--iters", type=int, default=1) + args = p.parse_args() + + pages, topk = (32768, 2048) if args.config == "A" else (2048, 30) + x, dense_indptr, sparse_indptr, dense_indices, out, reserved = make_inputs( + args.eff_bs, pages, topk + ) + + if args.kernel == "fused": + def call(): + topk_output_sglang_fused( + x, dense_indptr, sparse_indptr, dense_indices, out, + args.eff_bs, topk, reserved, reserved, pages, + args.mode, args.power, None, None, + ) + else: + def call(): + topk_output_sglang_parallel( + x, dense_indptr, sparse_indptr, dense_indices, out, + args.eff_bs, topk, reserved, reserved, pages, + args.num_splits, args.mode, args.power, None, None, + ) + + # Warmup: specialised kernel is JIT-instantiated and cudaFuncSetAttribute + # is cached; these launches dominate the first-call overhead and we want + # ncu to skip past them. + for _ in range(args.warmup): + call() + torch.cuda.synchronize() + + # Profiled region. Wrap in NVTX so the same script is also useful under + # Nsight Systems (nsys) if you prefer a timeline view. + torch.cuda.nvtx.range_push( + f"profile-{args.kernel}-mode{args.mode}-cfg{args.config}-eff{args.eff_bs}" + ) + for _ in range(args.iters): + call() + torch.cuda.synchronize() + torch.cuda.nvtx.range_pop() + + +if __name__ == "__main__": + main() diff --git a/csrc/register.cc b/csrc/register.cc index 8aa5aea1..af584d37 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -30,6 +30,17 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power"), py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none()); + m.def("topk_output_sglang_parallel", &topk_output_sglang_parallel, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("num_splits"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("mapping_lut") = py::none(), + py::arg("mapping_quantiles") = py::none()); m.def("topk_remap_only", &topk_remap_only, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("remapped"), diff --git a/csrc/register.h b/csrc/register.h index afdb97f0..e5a26def 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -126,6 +126,24 @@ std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt ); +void topk_output_sglang_parallel( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t num_splits, +const int64_t mapping_mode, +const double mapping_power, +std::optional mapping_lut = std::nullopt, +std::optional mapping_quantiles = std::nullopt +); + void topk_remap_only( const at::Tensor& x, const at::Tensor& dense_kv_indptr, diff --git a/csrc/topk_sglang_parallel.cu b/csrc/topk_sglang_parallel.cu new file mode 100644 index 00000000..72193917 --- /dev/null +++ b/csrc/topk_sglang_parallel.cu @@ -0,0 +1,811 @@ +/** + * Vortex TopK parallel kernel (single-kernel, last-CTA-wins merge). + * + * Motivation: the single-CTA fused kernel in topk_sglang.cu pins each + * batch segment to one CTA, which underutilises the GPU for small + * effective batch sizes (e.g. bs=4 on H100 leaves ~97% of SMs idle). + * + * This kernel launches `num_splits * eff_batch_size` CTAs in a single + * launch. CTAs sharing the same `bx` (batch index) partition that + * batch's score range `num_splits` ways and each compute a per-partition + * top-K via the same two-stage radix the fused kernel uses. Partial + * results are written into a per-batch workspace. + * + * Merge is done WITHOUT a second kernel launch. Each CTA, after + * finishing its partition's top-K, does `atomicAdd(&done_counter[bx], + * 1)`. The CTA whose atomicAdd returns `num_splits - 1` is the last + * one to arrive for batch bx, and it alone carries out the merge: + * reads the `num_splits * topk_val` candidates from the workspace, + * runs a small two-stage radix on the already-remapped keys, writes + * final top-K page IDs to sparse_kv_indices. + * + * Correctness: per-partition top-K is a conservative upper bound on + * the global top-K (worst case: all top-K items land in one + * partition). Every global top-K item is therefore guaranteed to be + * in some partition's top-K, and the merge picks the final top-K + * from the union — sorted-scores match the fused kernel exactly. + * Tie-breaking can differ because radix tie-breaks depend on atomic + * race order. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "register.h" + +namespace { + +// ---- Launch constants (match topk_sglang.cu) -------------------------------- + +constexpr int kThreadsPerBlock = 1024; + +#ifdef USE_ROCM +#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES +constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); +#else +constexpr size_t kSmem = 48 * 1024; +#endif +#else +constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32 KB +#endif + +constexpr size_t kFusedSmemMax = 96 * 1024; // combined kernel dynamic smem ceiling +constexpr int VORTEX_MAX_TOPK = 2048; + +// ---- Program-lifetime done-counter array ---------------------------------- +// Used by the last-CTA-wins barrier. __device__ linkage → zero-initialised at +// program startup. atomicInc(ptr, num_splits-1) cycles each entry back to 0 +// after every launch, so we never pay a cudaMemset on entry to the host fn. +// Sized for the largest realistic effective batch we'd ever run through the +// parallel kernel (decode bs×heads). Host validates the cap. +constexpr int kMaxParallelEffBs = 8192; +__device__ int g_parallel_done_counter[kMaxParallelEffBs]; + +// ---- Device helpers (duplicated from topk_sglang.cu) ----------------------- + +__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { + const uint32_t bits = __float_as_uint(x); + const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + return static_cast((key >> 16) & 0xFFu); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +#include "topk_mapping.cuh" + +// ============================================================================ +// fast_topk_partition +// +// Per-partition two-stage radix. Same algorithm as the fused kernel's +// fast_topk_clean_fused in topk_sglang.cu, with identical mapping-mode +// dispatch and bucket selection. Returns slice-local indices of the +// top `target_k` elements in `index`. +// +// Reuses the caller-provided extern shared memory region `f_input_idx` +// (2 × SMEM_INPUT_SIZE ints) and the `s_bins` byte cache immediately +// after it. The caller also supplies the static histogram / counter +// storage through the template's body — each device-function-private +// __shared__ declaration gets its own offset, but total static smem +// stays small enough to fit comfortably alongside the dynamic region. +// ============================================================================ +template +__device__ void fast_topk_partition( + const ScoreT* __restrict__ input, + int* __restrict__ index, + int* __restrict__ f_input_idx_raw, // 2 × SMEM_INPUT_SIZE ints + uint8_t* __restrict__ s_bins, // `length` bytes + int row_start, + int length, + int target_k, + const TopKMappingParams mapping) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + + alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int f_counter; + alignas(128) __shared__ int f_threshold_bin_id; + alignas(128) __shared__ int f_num_input[2]; + + auto& f_histogram = f_histogram_buf[0]; + + // Treat the caller's extern-smem region as two banks of SMEM_INPUT_SIZE ints. + auto f_input_idx = [&](int bank, int pos) -> int& { + return f_input_idx_raw[bank * SMEM_INPUT_SIZE + pos]; + }; + + const int tx = threadIdx.x; + + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + // Stage 1 pass 1: bin every element and cache the bin in s_bins so + // pass 2 doesn't re-load scores or re-apply the mapping. + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(remapped)); + else bin = static_cast(convert_to_uint8(remapped)); + s_bins[idx] = static_cast(bin); + ::atomicAdd(&f_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = f_histogram_buf[k][tx]; + if (tx < RADIX - j) value += f_histogram_buf[k][tx + j]; + f_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[0] = 0; + f_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + + constexpr int sub_bin_offset_start = use_dense_bucket ? 8 : 24; + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + const int bin = static_cast(s_bins[idx]); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto pos = ::atomicAdd(&f_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx(0, pos) = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int f_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = f_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input + : int(SMEM_INPUT_SIZE); + run_cumsum(); + if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { + f_threshold_bin_id = tx; + f_num_input[r_idx ^ 1] = 0; + f_last_remain = topk - f_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = f_threshold_bin_id; + topk -= f_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx(r_idx, i); + const auto offset = stage2_offset_start - round * 8; + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) f_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = f_input_idx(r_idx, i); + const float raw = vortex_to_float(input[idx + row_start]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&f_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&f_last_remain, -1); + if (pos > 0) index[target_k - pos] = idx; + } else { + const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + f_input_idx(r_idx ^ 1, pos) = idx; + const auto b32 = convert_to_uint32(remapped); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&f_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ============================================================================ +// fast_topk_merge +// +// Run by the last-arriving CTA of each batch. Input is the combined +// candidate list (`num_splits * topk_val` float keys + int indices, +// with idx==-1 marking sentinel slots). Reuses the same extern-smem +// region `s_input_idx_raw` that Phase 1 used — its earlier contents +// are dead at this point. Output: top-`target_k` positions into +// `index`, indexing the combined candidate list. +// +// Bucketing matches the fused kernel's bucketing for the given MODE +// so the merged top-K is lossless modulo atomic tie-break order. +// ============================================================================ +template +__device__ void fast_topk_merge( + const float* __restrict__ input, + const int* __restrict__ valid_mask, + int* __restrict__ index, + int* __restrict__ s_input_idx_raw, // 2 × SMEM_INPUT_SIZE ints + int row_start, + int length, + int target_k) +{ + int topk = target_k; + constexpr auto BLOCK_SIZE = 1024; + constexpr auto RADIX = 256; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; + constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; + + alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin_id; + alignas(128) __shared__ int s_num_input[2]; + + auto& s_histogram = s_histogram_buf[0]; + auto s_input_idx = [&](int bank, int pos) -> int& { + return s_input_idx_raw[bank * SMEM_INPUT_SIZE + pos]; + }; + + const int tx = threadIdx.x; + + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + if (valid_mask[idx + row_start] < 0) continue; // sentinel; skip + const float v = input[idx + row_start]; + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(v)); + else bin = static_cast(convert_to_uint8(v)); + ::atomicAdd(&s_histogram[bin], 1); + } + __syncthreads(); + + const auto run_cumsum = [&] { +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const auto j = 1 << i; + const auto k = i & 1; + auto value = s_histogram_buf[k][tx]; + if (tx < RADIX - j) value += s_histogram_buf[k][tx + j]; + s_histogram_buf[k ^ 1][tx] = value; + } + __syncthreads(); + } + }; + + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[0] = 0; + s_counter = 0; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + if (valid_mask[idx + row_start] < 0) continue; + const float v = input[idx + row_start]; + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(v)); + else bin = static_cast(convert_to_uint8(v)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + return; + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + + for (int idx = tx; idx < length; idx += BLOCK_SIZE) { + if (valid_mask[idx + row_start] < 0) continue; + const auto raw_input = input[idx + row_start]; + int bin; + if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(raw_input)); + else bin = static_cast(convert_to_uint8(raw_input)); + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + const auto pos = ::atomicAdd(&s_num_input[0], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx(0, pos) = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> stage2_offset_start) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + __syncthreads(); + } + +#pragma unroll 4 + for (int round = 0; round < 4; ++round) { + if (round >= stage2_max_rounds) break; + __shared__ int s_last_remain; + const auto r_idx = round % 2; + + const auto _raw_num_input = s_num_input[r_idx]; + const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input + : int(SMEM_INPUT_SIZE); + run_cumsum(); + if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { + s_threshold_bin_id = tx; + s_num_input[r_idx ^ 1] = 0; + s_last_remain = topk - s_histogram[tx + 1]; + } + __syncthreads(); + + const auto threshold_bin = s_threshold_bin_id; + topk -= s_histogram[threshold_bin + 1]; + + if (topk == 0) { + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx(r_idx, i); + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } + } + __syncthreads(); + break; + } else { + __syncthreads(); + if (tx < RADIX + 1) s_histogram[tx] = 0; + __syncthreads(); + for (int i = tx; i < num_input; i += BLOCK_SIZE) { + const auto idx = s_input_idx(r_idx, i); + const auto raw_input = input[idx + row_start]; + const auto offset = stage2_offset_start - round * 8; + const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; + if (bin > threshold_bin) { + const auto pos = ::atomicAdd(&s_counter, 1); + index[pos] = idx; + } else if (bin == threshold_bin) { + if (round == stage2_max_rounds - 1) { + const auto pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) index[target_k - pos] = idx; + } else { + const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); + if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { + s_input_idx(r_idx ^ 1, pos) = idx; + const auto b32 = convert_to_uint32(raw_input); + const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; + ::atomicAdd(&s_histogram[sub_bin], 1); + } + } + } + } + __syncthreads(); + } + } +} + +// ============================================================================ +// Combined kernel. +// +// Grid: (num_splits, eff_batch_size). Every CTA: +// 1. Computes its partition's top-K (fast_topk_partition). +// 2. Writes (remapped key, batch-local idx) pairs + sentinels to the +// per-batch workspace slot. +// 3. __threadfence() to publish the writes, then atomicAdd on the +// per-batch done-counter. The CTA whose atomicAdd returns +// num_splits - 1 is the last one for this batch. +// 4. If last: run the merge (fast_topk_merge) on the combined +// num_splits*topk_val candidates and write final page IDs to +// sparse_kv_indices. Other CTAs exit. +// ============================================================================ +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopKOutput_Parallel_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + float* __restrict__ partial_keys, // [eff_bs * num_splits * topk_val] + int* __restrict__ partial_idx, // [eff_bs * num_splits * topk_val] + const int topk_val, + const int num_splits, + const int page_reserved_bos, + const int page_reserved_eos, + const int chunk_bytes, // smem bytes reserved for s_bins + const TopKMappingParams mapping) +{ + // ---- Dynamic smem layout ------------------------------------------------- + // [ f_input_idx (2 × SMEM_INPUT_SIZE ints = kSmem bytes) + // s_bins (chunk_bytes, only valid during Phase 1) ] + // The merge doesn't touch s_bins, so its extern region overlaps + // f_input_idx harmlessly. + extern __shared__ int smem_scratch[]; + constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); + int* f_input_idx_raw = smem_scratch; + uint8_t* s_bins = reinterpret_cast(&smem_scratch[2 * SMEM_INPUT_SIZE]); + (void)chunk_bytes; // sizing is the host's responsibility; kernel just uses it + + // s_indices doubles as the partition's radix output AND the merge's radix + // output — they run sequentially on the same CTA, so the same ~2K slots + // are reused. Stores up to VORTEX_MAX_TOPK = 2048 entries. + __shared__ int s_indices[VORTEX_MAX_TOPK]; + // Broadcasts whether this CTA is the last-arriving one for its batch. + __shared__ int s_is_last; + + const int p = blockIdx.x; + const int bx = blockIdx.y; + const int tx = threadIdx.x; + + const int start = dense_kv_indptr[bx] + page_reserved_bos; + const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; + const int total_len = end - start; + + // Short batch: fused kernel returns without writing; match that. + if (total_len <= topk_val) return; + + const size_t slot_base = (static_cast(bx) * num_splits + p) * topk_val; + float* keys_out = partial_keys + slot_base; + int* idx_out = partial_idx + slot_base; + + const int chunk = (total_len + num_splits - 1) / num_splits; + const int part_start = p * chunk; + const int raw_part_end = part_start + chunk; + const int part_end = raw_part_end < total_len ? raw_part_end : total_len; + const int part_len = (part_end > part_start) ? (part_end - part_start) : 0; + + // Sentinel tail: merge filters these by idx == -1. Only fill the range + // that won't be overwritten with real data. + const int real_fill = (part_len < topk_val) ? part_len : topk_val; + const int tail_count = topk_val - real_fill; + if (tail_count > 0) { + for (int i = tx; i < tail_count; i += blockDim.x) { + keys_out[real_fill + i] = -CUDART_INF_F; + idx_out [real_fill + i] = -1; + } + __syncthreads(); + } + + const ScoreT* __restrict__ slice_ptr = score + start + part_start; + + // ---- Phase 1: per-partition top-K --------------------------------------- + if (part_len > 0) { + if (part_len <= topk_val) { + // Whole slice fits under topk_val — emit it directly. + for (int i = tx; i < part_len; i += blockDim.x) { + const float raw = vortex_to_float(slice_ptr[i]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + keys_out[i] = remapped; + idx_out [i] = part_start + i; + } + } else { + fast_topk_partition( + slice_ptr, s_indices, f_input_idx_raw, s_bins, + 0, part_len, topk_val, mapping); + __syncthreads(); + for (int i = tx; i < topk_val; i += blockDim.x) { + const int sl = s_indices[i]; + const float raw = vortex_to_float(slice_ptr[sl]); + const float remapped = apply_transform_tmpl(raw, mapping.power_exp); + keys_out[i] = remapped; + idx_out [i] = part_start + sl; + } + } + } + + // Publish workspace writes so the last-CTA can observe them. + __threadfence(); + __syncthreads(); + + // ---- Arrive at the barrier via atomicInc -------------------------------- + // atomicInc(ptr, N-1) stores `((old >= N-1) ? 0 : old+1)` and returns old. + // So with N == num_splits the counter cycles 0→1→…→N-1→0 per call, which + // means we never need to memset done_counter between calls — after the + // last-CTA's increment it's back at 0, ready for the next launch. + // (Relies on the caller allocating done_counter zero-initialised once.) + if (tx == 0) { + const unsigned int old = ::atomicInc( + reinterpret_cast(&g_parallel_done_counter[bx]), + static_cast(num_splits - 1)); + s_is_last = (old == static_cast(num_splits - 1)) ? 1 : 0; + } + __syncthreads(); + + if (s_is_last == 0) return; + + // ---- Merge: last CTA selects final top-K -------------------------------- + const int candidate_len = num_splits * topk_val; + const size_t batch_base = static_cast(bx) * candidate_len; + const float* keys_blk = partial_keys + batch_base; + const int* idx_blk = partial_idx + batch_base; + int* out_blk = sparse_kv_indices + + sparse_kv_indptr[bx] + + page_reserved_bos; + const int* dense_blk = dense_kv_indices + start; + + fast_topk_merge( + keys_blk, idx_blk, s_indices, f_input_idx_raw, + 0, candidate_len, topk_val); + __syncthreads(); + + for (int i = tx; i < topk_val; i += blockDim.x) { + const int pos = s_indices[i]; + const int batch_local = idx_blk[pos]; + out_blk[i] = (batch_local >= 0) ? dense_blk[batch_local] : -1; + } +} + +// ---- setup_kernel_smem_once (duplicated) ----------------------------------- + +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { +#ifdef USE_ROCM + return ::cudaFuncSetAttribute( + reinterpret_cast(f), + ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#else + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); +#endif + }(); + TORCH_CHECK(result == cudaSuccess, + "set_up_kernel_once (parallel) failed:", ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// ============================================================================ +// Host entry point. +// +// Signature matches topk_output_sglang_fused plus `num_splits`. +// `num_splits <= 1` delegates to the single-CTA fused kernel so callers +// can unconditionally use this path. +// ============================================================================ +void topk_output_sglang_parallel( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t num_splits, + const int64_t mapping_mode, + const double mapping_power, + std::optional mapping_lut, + std::optional mapping_quantiles) +{ + TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, + "topk_output_sglang_parallel: topk_val (", topk_val, + ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); + TORCH_CHECK(num_splits >= 1, + "topk_output_sglang_parallel: num_splits must be >= 1"); + + if (num_splits <= 1) { + topk_output_sglang_fused( + x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, + sparse_kv_indices, eff_batch_size, topk_val, + reserved_bos, reserved_eos, max_num_pages, + mapping_mode, mapping_power, mapping_lut, mapping_quantiles); + return; + } + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + (void)mapping_lut; + (void)mapping_quantiles; + + TopKMappingParams mapping{}; + mapping.mode = static_cast(mapping_mode); + mapping.power_exp = static_cast(mapping_power); + mapping.lut = nullptr; + mapping.quantiles = nullptr; + + // Dynamic smem = kSmem (f_input_idx) + chunk_bytes (s_bins for the + // partition radix; the merge doesn't touch s_bins). + const int64_t chunk_pages = (max_num_pages + num_splits - 1) / num_splits; + const size_t chunk_bytes = (static_cast(chunk_pages) + size_t(15)) & ~size_t(15); + const size_t smem_bytes = kSmem + chunk_bytes; + TORCH_CHECK(smem_bytes <= kFusedSmemMax, + "topk_output_sglang_parallel: smem ", smem_bytes, + " exceeds ceiling ", kFusedSmemMax); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + TORCH_CHECK(eff_batch_size <= kMaxParallelEffBs, + "topk_output_sglang_parallel: eff_batch_size (", eff_batch_size, + ") exceeds kMaxParallelEffBs (", kMaxParallelEffBs, + "). Raise the __device__ counter array size."); + + // Per-call workspace. at::empty, no zero-init — kernel fills every used + // slot (valid prefix + sentinel tail). done_counter is a __device__ + // global (above) so no workspace allocation needed for it. + const int64_t ws_elems = eff_batch_size * num_splits * topk_val; + auto opts_f32 = at::TensorOptions().device(x.device()).dtype(at::kFloat); + auto opts_i32 = at::TensorOptions().device(x.device()).dtype(at::kInt); + at::Tensor partial_keys = at::empty({ws_elems}, opts_f32); + at::Tensor partial_idx = at::empty({ws_elems}, opts_i32); + + dim3 grid(static_cast(num_splits), + static_cast(eff_batch_size)); + dim3 nthreads(kThreadsPerBlock); + + #define VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once< \ + TopKOutput_Parallel_Kernel, \ + kFusedSmemMax>(); \ + TopKOutput_Parallel_Kernel \ + <<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + partial_keys.data_ptr(), \ + partial_idx.data_ptr(), \ + static_cast(topk_val), \ + static_cast(num_splits), \ + static_cast(reserved_bos), \ + static_cast(reserved_eos), \ + static_cast(chunk_bytes), \ + mapping); \ + } while (0) + + #define VORTEX_PARALLEL_DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping.mode) { \ + case MAPPING_NONE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_TRUNC8: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break; \ + case MAPPING_ERF: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP:VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + case MAPPING_DENSE_MANT: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break; \ + default: \ + TORCH_CHECK(false, \ + "topk_output_sglang_parallel: unsupported mapping_mode ", \ + mapping.mode); \ + } \ + } while (0) + + if (x.scalar_type() == at::ScalarType::BFloat16) { + VORTEX_PARALLEL_DISPATCH_MODE( + __nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); + } else if (x.scalar_type() == at::ScalarType::Float) { + VORTEX_PARALLEL_DISPATCH_MODE(float, x.data_ptr()); + } else { + TORCH_CHECK(false, "topk_output_sglang_parallel: unsupported dtype ", + x.scalar_type()); + } + + #undef VORTEX_PARALLEL_DISPATCH_MODE + #undef VORTEX_PARALLEL_DISPATCH + + const auto result = cudaGetLastError(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_sglang_parallel kernel failed: ", + ::cudaGetErrorString(result)); +} diff --git a/examples/ablation_remap_function_block_size.sh b/examples/ablation_remap_function_block_size.sh new file mode 100644 index 00000000..4bf5bba5 --- /dev/null +++ b/examples/ablation_remap_function_block_size.sh @@ -0,0 +1,279 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. block (page) size +# +# Sweeps BLOCK_SIZE and, for every cell, runs the full +# calibrate -> autotune -> remap-bench +# pipeline so the per-mode hyperparameter is freshly chosen by +# autotune for that block size (NOT hardcoded). +# +# Mapping modes under test (matches the screenshot): +# 0 none — unmapped baseline +# 3 power — p +# 6 asinh — beta +# 7 log1p — alpha +# 9 erf — alpha +# 10 tanh — alpha +# 11 subtract — pivot +# 13 exp_stretch — alpha +# 15 shift_pow2 — pivot +# 16 shift_pow3 — pivot +# 17 linear_steep — k +# +# Output: +# results/ablation_remap_block_size_/ +# bs/{autotune_results.json, remap_bench.json, step{1,2,3}_*.log} +# sweep_index.json +# selected_hparams.txt — per-cell screenshot-style summary +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +BLOCK_SIZES="1 2 4 8 16 32 64" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --block-sizes) BLOCK_SIZES="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_block_size_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +# Per-model calibration cache (reused across block_size cells: page size +# does not change the per-segment score distribution). +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs block_size" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block sizes: ${BLOCK_SIZES}" +echo " Batch size: ${BATCH_SIZE}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Step 0: Calibrate once for this model ────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: SKIPPED calibration (using ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo ">>> Step 0: Calibrating ${MODEL_NAME} for raw_histograms.npy" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size 1 \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 0: Done. raw_histograms -> ${REAL_HIST_PATH}" +fi + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +echo "{" > "${SWEEP_INDEX}" +echo " \"axis_name\": \"block_size\"," >> "${SWEEP_INDEX}" +echo " \"axis_type\": \"kernel\"," >> "${SWEEP_INDEX}" +echo " \"model_name\": \"${MODEL_NAME}\"," >> "${SWEEP_INDEX}" +echo " \"topk_val\": ${TOPK_VAL}," >> "${SWEEP_INDEX}" +echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," >> "${SWEEP_INDEX}" +echo " \"cells\": [" >> "${SWEEP_INDEX}" + +FIRST_CELL=1 +for BLOCK_SIZE in ${BLOCK_SIZES}; do + # Pick a seq_len that satisfies pages/seg > topk_val + 3 reserved. + MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) + SEQ_LEN=${MIN_SEQ_LEN} + # Round up to next power-of-two-ish multiple of 1024 for stable timing. + if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + + CELL_DIR="${SWEEP_DIR}/bs${BLOCK_SIZE}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: block_size=${BLOCK_SIZE} seq_len=${SEQ_LEN}" + echo "============================================================" + + echo ">>> Autotuning hparams for block_size=${BLOCK_SIZE}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench for block_size=${BLOCK_SIZE}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import MODE_NAMES, PARAM_NAME +except Exception: + MODE_NAMES = {0: "none", 3: "power", 6: "asinh", 7: "log1p", 9: "erf", + 10: "tanh", 11: "subtract", 13: "exp_stretch", + 15: "shift_pow2", 16: "shift_pow3", 17: "linear_steep"} + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} + +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path = sys.argv[1], sys.argv[2] +with open(idx_path) as f: + idx = json.load(f) + +lines = ["== Selected mapping functions (autotuned, block_size sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m not in best: + continue + pname = PARAM_NAME.get(m, "p") + pval = best[m].get("param", 0.0) + parts.append(f"{DISPLAY[m]}({pname}={pval})") + lines.append(f"[block_size={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "Block-size ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Per-cell results: ${SWEEP_DIR}/bs/" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_model.sh b/examples/ablation_remap_function_model.sh new file mode 100644 index 00000000..0212b83a --- /dev/null +++ b/examples/ablation_remap_function_model.sh @@ -0,0 +1,262 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. model +# +# Sweeps MODEL_NAME across the Qwen3 family. For every model: +# 1. Calibrate (or reuse cached raw_histograms_.npy) +# 2. Autotune the per-mode hparam on that model's histogram +# (NOT hardcoded; freshly tuned per model) +# 3. Remap-bench across the autotuned hparams +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODELS="Qwen/Qwen3-0.6B Qwen/Qwen3-1.7B Qwen/Qwen3-4B Qwen/Qwen3-8B" +TOPK_VAL=2048 +BLOCK_SIZE=1 +MEM=0.7 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --models) MODELS="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +# Per-model max-total-tokens (KV pool cap for calibration). Larger models +# need a smaller cap so they fit at MEM=0.7. Override by passing the env +# var MAX_TOTAL_TOKENS_=N before invocation. +declare -A MAX_TOTAL_TOKENS_LUT +MAX_TOTAL_TOKENS_LUT["qwen3-0.6B"]=131072 +MAX_TOTAL_TOKENS_LUT["qwen3-1.7B"]=64768 +MAX_TOTAL_TOKENS_LUT["qwen3-4B"]=32768 +MAX_TOTAL_TOKENS_LUT["qwen3-8B"]=16384 + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_model_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +mkdir -p "${CALIBRATION_BASE}" + +echo "============================================================" +echo "Ablation: remap function vs model" +echo " Models: ${MODELS}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"model\"," + echo " \"axis_type\": \"kernel\"," + echo " \"topk_val\": ${TOPK_VAL}," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +# Pick a single seq_len that satisfies pages/seg > topk_val for all models. +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +SEQ_LEN=${MIN_SEQ_LEN} +if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + +FIRST_CELL=1 +for MODEL_NAME in ${MODELS}; do + MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" + MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" + DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" + + # Per-model max-total-tokens (override-able via env). + MTT_DEFAULT="${MAX_TOTAL_TOKENS_LUT[${MODEL_TAG}]:-32768}" + ENV_KEY="MAX_TOTAL_TOKENS_$(echo "${MODEL_TAG}" | tr '.-' '__')" + MAX_TOTAL_TOKENS="${!ENV_KEY:-${MTT_DEFAULT}}" + + CELL_DIR="${SWEEP_DIR}/${MODEL_SLUG}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: model=${MODEL_NAME} (max_total_tokens=${MAX_TOTAL_TOKENS})" + echo "============================================================" + + # Step 1: calibrate (cached per-model) + if [ -f "${DEFAULT_REAL_HIST}" ]; then + echo ">>> Calibration cache hit: ${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + else + echo ">>> Calibrating ${MODEL_NAME}" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${CELL_DIR}/step1_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + fi + + # Step 2: autotune + echo ">>> Autotuning hparams for ${MODEL_NAME}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + # Step 3: remap bench + echo ">>> Remap bench for ${MODEL_NAME}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "model" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import PARAM_NAME +except Exception: + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3] +with open(idx_path) as f: + idx = json.load(f) + +lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m in best: + parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})") + lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "Model ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_topk_benchmark.sh b/examples/ablation_remap_function_topk_benchmark.sh new file mode 100644 index 00000000..7952bd20 --- /dev/null +++ b/examples/ablation_remap_function_topk_benchmark.sh @@ -0,0 +1,277 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. topk-kernel benchmark workload +# +# Sweeps the kernel-bench INPUT distribution (the workload that +# stresses the TopK kernel) and, per cell, runs autotune + +# remap-bench so the per-mode hparam is freshly chosen for that +# distribution. This is the robustness ablation: do the +# autotuned remap functions still beat the unmapped baseline +# when the input score distribution shifts? +# +# Distributions available in bench_topk.py: +# normal — N(0,1) per-page scores +# lognormal — heavy-tailed positive scores +# uniform — U[0,1) +# bucket_uniform— per-bucket uniform (worst case for radix) +# real — sampled from raw_histograms_.npy +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +BLOCK_SIZE=1 +SEQ_LEN=32768 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# Distributions to sweep (one per cell). "real" requires raw_histograms.npy. +DISTRIBUTION_LIST="normal lognormal uniform bucket_uniform real" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --dist-list|--distributions) DISTRIBUTION_LIST="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + SEQ_LEN=${MIN_SEQ_LEN} +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_topk_benchmark_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +# Need raw_histograms only if "real" is in the distribution list. +NEED_REAL=0 +for d in ${DISTRIBUTION_LIST}; do + if [ "$d" = "real" ]; then NEED_REAL=1; fi +done + +if [ "${NEED_REAL}" -eq 1 ] && [ -z "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: Calibrating ${MODEL_NAME} (needed for distribution=real)" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs topk-kernel benchmark workload" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " TopK: ${TOPK_VAL}" +echo " Seq len: ${SEQ_LEN}" +echo " Distributions: ${DISTRIBUTION_LIST}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"distribution\"," + echo " \"axis_type\": \"kernel\"," + echo " \"model_name\": \"${MODEL_NAME}\"," + echo " \"topk_val\": ${TOPK_VAL}," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"seq_len\": ${SEQ_LEN}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +FIRST_CELL=1 +for DIST in ${DISTRIBUTION_LIST}; do + CELL_DIR="${SWEEP_DIR}/dist_${DIST}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: distribution=${DIST}" + echo "============================================================" + + AUTOTUNE_DIST_ARGS=() + BENCH_DIST_ARGS=() + if [ "${DIST}" = "real" ]; then + AUTOTUNE_DIST_ARGS=(--real-histograms "${REAL_HISTOGRAMS}") + BENCH_DIST_ARGS=(--real-histograms "${REAL_HISTOGRAMS}" --distributions real) + else + AUTOTUNE_DIST_ARGS=(--distributions "${DIST}") + BENCH_DIST_ARGS=(--distributions "${DIST}") + fi + + echo ">>> Autotuning hparams on dist=${DIST}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + "${AUTOTUNE_DIST_ARGS[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench on dist=${DIST}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + "${BENCH_DIST_ARGS[@]}" \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "distribution" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import PARAM_NAME +except Exception: + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3] +with open(idx_path) as f: + idx = json.load(f) + +lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m in best: + parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})") + lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "topk_benchmark (kernel workload) ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/ablation_remap_function_topk_val.sh b/examples/ablation_remap_function_topk_val.sh new file mode 100644 index 00000000..4e60440a --- /dev/null +++ b/examples/ablation_remap_function_topk_val.sh @@ -0,0 +1,255 @@ +#!/usr/bin/env bash +# ============================================================ +# Ablation: Remap function vs. topk_val +# +# Sweeps TOPK_VAL and, for every cell, runs +# autotune -> remap-bench +# so the per-mode hyperparameter is freshly chosen by autotune +# for that topk_val (NOT hardcoded). Calibration runs once for +# the model. +# +# Mapping modes under test (matches the screenshot): +# 0 none, 3 power, 6 asinh, 7 log1p, 9 erf, 10 tanh, +# 11 subtract, 13 exp_stretch, 15 shift_pow2, 16 shift_pow3, +# 17 linear_steep +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +BLOCK_SIZE=1 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +BATCH_SIZE=4 +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +TOPK_VALS="512 1024 2048 4096" +REAL_HISTOGRAMS="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --gpu) GPU_ID="$2"; shift 2 ;; + --model-name) MODEL_NAME="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --batch-size) BATCH_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --topk-vals) TOPK_VALS="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +SWEEP_DIR="${RESULTS_DIR}/ablation_remap_topk_val_${MODEL_SLUG}_${TIMESTAMP}" +mkdir -p "${SWEEP_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Ablation: remap function vs topk_val" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Topk vals: ${TOPK_VALS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " GPU: ${GPU_ID}" +echo " Sweep dir: ${SWEEP_DIR}" +echo "============================================================" + +# ── Step 0: Calibrate once for this model ────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo ">>> Step 0: SKIPPED calibration (using ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo ">>> Step 0: Calibrating ${MODEL_NAME}" + STAGING_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_${TIMESTAMP}" + mkdir -p "${STAGING_DIR}" + CAL_TOPK_VAL=$(echo "${TOPK_VALS}" | tr ' ' '\n' | sort -n | tail -n 1) + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${CAL_TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${STAGING_DIR}" \ + 2>&1 | tee "${SWEEP_DIR}/step0_calibrate.log" + mv -f "${STAGING_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" +fi + +# ── Sweep ────────────────────────────────────────────────────── +SWEEP_INDEX="${SWEEP_DIR}/sweep_index.json" +{ + echo "{" + echo " \"axis_name\": \"topk_val\"," + echo " \"axis_type\": \"kernel\"," + echo " \"model_name\": \"${MODEL_NAME}\"," + echo " \"block_size\": ${BLOCK_SIZE}," + echo " \"mapping_modes\": [${MAPPING_MODES// /, }]," + echo " \"cells\": [" +} > "${SWEEP_INDEX}" + +FIRST_CELL=1 +for TOPK_VAL in ${TOPK_VALS}; do + MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) + SEQ_LEN=${MIN_SEQ_LEN} + if [ "${SEQ_LEN}" -lt 8192 ]; then SEQ_LEN=8192; fi + if [ "${SEQ_LEN}" -lt $(( TOPK_VAL * BLOCK_SIZE * 4 )) ]; then + SEQ_LEN=$(( TOPK_VAL * BLOCK_SIZE * 4 )) + fi + + CELL_DIR="${SWEEP_DIR}/topk${TOPK_VAL}" + mkdir -p "${CELL_DIR}" + AUTOTUNE_JSON="${CELL_DIR}/autotune_results.json" + REMAP_JSON="${CELL_DIR}/remap_bench.json" + + echo "" + echo "============================================================" + echo ">>> Cell: topk_val=${TOPK_VAL} seq_len=${SEQ_LEN}" + echo "============================================================" + + echo ">>> Autotuning hparams for topk_val=${TOPK_VAL}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step2_autotune.log" + + echo ">>> Remap bench for topk_val=${TOPK_VAL}" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --per-head-bench \ + --batch-sizes "${BATCH_SIZE}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + --autotune-json "${AUTOTUNE_JSON}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${CELL_DIR}/step3_remap_bench.log" + + if [ "${FIRST_CELL}" -eq 1 ]; then + FIRST_CELL=0 + else + echo " ," >> "${SWEEP_INDEX}" + fi + cat >> "${SWEEP_INDEX}" <> "${SWEEP_INDEX}" +echo "}" >> "${SWEEP_INDEX}" + +# ── Per-cell screenshot-style hparam summary ────────────────── +SELECTED_TXT="${SWEEP_DIR}/selected_hparams.txt" +PYTHONPATH="${SCRIPT_DIR}/.." python3 - "${SWEEP_INDEX}" "${SELECTED_TXT}" "topk_val" <<'PY' +import json, sys, os +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "benchmarks")) +try: + from autotune_topk_mapping import PARAM_NAME +except Exception: + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} +DISPLAY = {3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", 10: "Tanh", + 11: "Subtract", 13: "ExpStretch", 15: "ShiftPow2", + 16: "ShiftPow3", 17: "LinearSteep"} + +idx_path, out_path, axis_name = sys.argv[1], sys.argv[2], sys.argv[3] +with open(idx_path) as f: + idx = json.load(f) + +lines = [f"== Selected mapping functions (autotuned, {axis_name} sweep) =="] +for cell in idx["cells"]: + with open(cell["autotune_json"]) as f: + results = json.load(f) + best = {} + for r in results: + m = r["mode"] + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + parts = [] + for m in sorted(DISPLAY): + if m in best: + parts.append(f"{DISPLAY[m]}({PARAM_NAME.get(m,'p')}={best[m].get('param',0.0)})") + lines.append(f"[{axis_name}={cell['axis_value']}] " + " ".join(parts)) + +txt = "\n".join(lines) + "\n" +print(txt) +with open(out_path, "w") as f: + f.write(txt) +PY + +echo "" +echo "============================================================" +echo "topk_val ablation complete." +echo " Sweep dir: ${SWEEP_DIR}" +echo " Sweep index: ${SWEEP_INDEX}" +echo " Selected hparams: ${SELECTED_TXT}" +echo "Run analyze with:" +echo " python examples/analyze_ablation_remap.py --sweep-dir ${SWEEP_DIR}" +echo "============================================================" diff --git a/examples/analyze_ablation_remap.py b/examples/analyze_ablation_remap.py new file mode 100644 index 00000000..18b0a0b5 --- /dev/null +++ b/examples/analyze_ablation_remap.py @@ -0,0 +1,416 @@ +#!/usr/bin/env python3 +""" +Analyze remap-function ablation sweeps. + +Reads one or more sweep directories produced by + ablation_remap_function_block_size.sh + ablation_remap_function_topk_val.sh + ablation_remap_function_model.sh + ablation_remap_function_topk_benchmark.sh + +and emits, for each sweep: + - tidy CSV of every (axis_value, mapping_mode, distribution, head) row + - wide CSV tables: latency, speedup vs baseline, chosen hparam + - LaTeX version of the chosen-hparam table + - markdown summary including the screenshot-style "Selected mapping + functions" line per axis value + - matplotlib PDF plots: latency vs axis, speedup vs axis, threshold + bin size vs axis (one curve per mapping mode) + +Usage: + python examples/analyze_ablation_remap.py \ + --sweep-dir results/ablation_remap_block_size_ \ + [--sweep-dir results/ablation_remap_model_ ...] \ + --output-dir results/ablation_remap_analysis_ +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +import sys +from pathlib import Path +from typing import Any, Dict, List, Optional + +import numpy as np +import pandas as pd +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +# Pull mode metadata from the autotune script so we don't duplicate it. +SCRIPT_DIR = Path(__file__).resolve().parent +BENCH_DIR = SCRIPT_DIR.parent / "benchmarks" +sys.path.insert(0, str(BENCH_DIR)) +try: + from autotune_topk_mapping import MODE_NAMES, PARAM_NAME # type: ignore +except Exception: + MODE_NAMES = {0: "none", 3: "power", 4: "log", 6: "asinh", 7: "log1p", + 8: "trunc8", 9: "erf", 10: "tanh", 11: "subtract", + 13: "exp_stretch", 15: "shift_pow2", 16: "shift_pow3", + 17: "linear_steep"} + PARAM_NAME = {3: "p", 6: "beta", 7: "alpha", 9: "alpha", 10: "alpha", + 11: "pivot", 13: "alpha", 15: "pivot", 16: "pivot", 17: "k"} + +DISPLAY_NAME = { + 0: "None", 3: "Power", 6: "Asinh", 7: "Log1p", 9: "Erf", + 10: "Tanh", 11: "Subtract", 13: "ExpStretch", + 15: "ShiftPow2", 16: "ShiftPow3", 17: "LinearSteep", +} + + +# ---------- Loading ---------- + +def _load_json(path: str) -> Any: + with open(path) as f: + return json.load(f) + + +def _best_per_mode_from_autotune(autotune_results: List[dict]) -> Dict[int, dict]: + best: Dict[int, dict] = {} + for r in autotune_results: + m = int(r["mode"]) + if m not in best or r["latency_ms"] < best[m]["latency_ms"]: + best[m] = r + return best + + +def _flatten_remap_bench(remap_results: List[dict]) -> pd.DataFrame: + """Flatten bench_topk.py --remap-bench output into one row per + (cfg, mode_row). Drops per-head sub-rows; keeps head='all' so each + cell contributes a single point per (mapping_mode, distribution).""" + rows = [] + for cfg in remap_results: + if cfg.get("head", "all") != "all": + continue + baseline = cfg.get("baseline_ms") + for mr in cfg.get("modes", []): + mode = int(mr["mode"]) + rows.append({ + "distribution": cfg.get("distribution"), + "batch_size": cfg.get("batch_size"), + "num_kv_heads": cfg.get("num_kv_heads"), + "seq_len": cfg.get("seq_len"), + "topk_val": cfg.get("topk_val"), + "pages_per_seg": cfg.get("pages_per_seg"), + "mode": mode, + "mode_name": mr.get("mode_name", MODE_NAMES.get(mode, str(mode))), + "param_value": mr.get("power"), + "fused_ms": mr.get("fused_ms"), + "remap_ms": mr.get("remap_ms"), + "topk_after_remap_ms": mr.get("topk_after_remap_ms"), + "split_total_ms": mr.get("split_total_ms"), + "baseline_ms": baseline, + "threshold_bin_size_mean": mr.get("threshold_bin_size_mean"), + "threshold_bin_size_max": mr.get("threshold_bin_size_max"), + "refine_rounds_mean": mr.get("refine_rounds_mean"), + }) + return pd.DataFrame(rows) + + +def load_sweep(sweep_dir: Path) -> Dict[str, Any]: + idx_path = sweep_dir / "sweep_index.json" + if not idx_path.exists(): + raise FileNotFoundError(f"missing sweep_index.json in {sweep_dir}") + idx = _load_json(idx_path) + axis_name = idx["axis_name"] + + rows: List[pd.DataFrame] = [] + chosen_hparams: List[dict] = [] + for cell in idx["cells"]: + axis_value = cell["axis_value"] + autotune_results = _load_json(cell["autotune_json"]) + best = _best_per_mode_from_autotune(autotune_results) + for mode, r in best.items(): + chosen_hparams.append({ + "axis_value": axis_value, + "mode": int(mode), + "mode_name": r.get("mode_name", MODE_NAMES.get(int(mode), str(mode))), + "param_name": r.get("param_name") or PARAM_NAME.get(int(mode), "p"), + "param_value": r.get("param"), + "autotune_latency_ms": r.get("latency_ms"), + }) + + remap_results = _load_json(cell["remap_bench_json"]) + df = _flatten_remap_bench(remap_results) + df.insert(0, "axis_value", axis_value) + df.insert(0, "axis_name", axis_name) + rows.append(df) + + tidy = pd.concat(rows, ignore_index=True) if rows else pd.DataFrame() + chosen = pd.DataFrame(chosen_hparams) + return { + "axis_name": axis_name, + "axis_type": idx.get("axis_type", "kernel"), + "index": idx, + "tidy": tidy, + "chosen": chosen, + } + + +# ---------- Tables ---------- + +def _wide_latency(tidy: pd.DataFrame, axis_name: str, distribution: Optional[str] = None) -> pd.DataFrame: + df = tidy.copy() + if distribution is not None and "distribution" in df.columns: + df = df[df["distribution"] == distribution] + # Best fused latency per (axis_value, mode) — collapse over distribution + # if no filter was applied. + g = df.groupby(["axis_value", "mode", "mode_name"], dropna=False)["fused_ms"].min().reset_index() + wide = g.pivot(index="axis_value", columns="mode", values="fused_ms") + # Also pivot mode_name → label for column header. + return wide.rename(columns=lambda m: f"{m}:{MODE_NAMES.get(int(m), '?')}") + + +def _wide_baseline(tidy: pd.DataFrame, distribution: Optional[str] = None) -> pd.Series: + df = tidy.copy() + if distribution is not None and "distribution" in df.columns: + df = df[df["distribution"] == distribution] + return df.groupby("axis_value")["baseline_ms"].min() + + +def _wide_speedup(tidy: pd.DataFrame, axis_name: str, distribution: Optional[str] = None) -> pd.DataFrame: + lat = _wide_latency(tidy, axis_name, distribution=distribution) + base = _wide_baseline(tidy, distribution=distribution) + return lat.rdiv(base, axis=0) # baseline / fused + + +def _wide_chosen_hparam(chosen: pd.DataFrame) -> pd.DataFrame: + if chosen.empty: + return pd.DataFrame() + chosen = chosen.copy() + chosen["label"] = chosen.apply( + lambda r: f"{DISPLAY_NAME.get(int(r['mode']), r['mode_name'])}({r['param_name']}={r['param_value']})", + axis=1, + ) + wide = chosen.pivot(index="axis_value", columns="mode", values="label") + return wide.rename(columns=lambda m: f"{m}:{MODE_NAMES.get(int(m), '?')}") + + +def _df_to_latex(df: pd.DataFrame, caption: str, label: str) -> str: + if df.empty: + return f"% empty table for {label}\n" + try: + return df.to_latex( + float_format=lambda v: "" if pd.isna(v) else f"{v:.4f}", + na_rep="", + caption=caption, + label=label, + ) + except Exception: + return df.to_string() + + +# ---------- Plots ---------- + +def _axis_x(values: List[Any]) -> List[float]: + """Convert axis values (which may be strings or ints) to numeric x + coordinates. Strings are mapped to 0..N-1; numerics keep their value.""" + out = [] + for i, v in enumerate(values): + if isinstance(v, (int, float)): + out.append(float(v)) + else: + out.append(float(i)) + return out + + +def _plot_metric_vs_axis(tidy: pd.DataFrame, axis_name: str, metric: str, + out_path: Path, ylabel: str, title: str, + baseline_series: Optional[pd.Series] = None, + logy: bool = False) -> None: + if tidy.empty: + return + g = tidy.groupby(["axis_value", "mode", "mode_name"], dropna=False)[metric].min().reset_index() + axis_values = sorted(g["axis_value"].unique(), + key=lambda v: (not isinstance(v, (int, float)), v)) + x = _axis_x(axis_values) + + fig, ax = plt.subplots(figsize=(7, 4.5)) + cmap = plt.cm.get_cmap("tab10") + for i, mode in enumerate(sorted(g["mode"].unique())): + sub = g[g["mode"] == mode].set_index("axis_value").reindex(axis_values) + ax.plot(x, sub[metric].values, + marker="o", color=cmap(i % 10), + label=f"{mode}:{MODE_NAMES.get(int(mode), '?')}") + + if baseline_series is not None and not baseline_series.empty: + bx = baseline_series.reindex(axis_values).values + ax.plot(x, bx, "k--", linewidth=2, label="baseline (unmapped)") + + ax.set_xlabel(axis_name) + ax.set_ylabel(ylabel) + ax.set_title(title) + if logy: + ax.set_yscale("log") + if all(isinstance(v, (int, float)) for v in axis_values): + ax.set_xticks(x) + ax.set_xticklabels([str(v) for v in axis_values]) + else: + ax.set_xticks(x) + ax.set_xticklabels([str(v) for v in axis_values], rotation=20, ha="right") + ax.grid(True, alpha=0.3) + ax.legend(fontsize=7, ncol=2, loc="best") + fig.tight_layout() + fig.savefig(out_path) + plt.close(fig) + + +# ---------- Per-sweep emitters ---------- + +def emit_sweep(sweep: Dict[str, Any], out_root: Path) -> None: + axis_name = sweep["axis_name"] + out_dir = out_root / axis_name + out_dir.mkdir(parents=True, exist_ok=True) + + tidy: pd.DataFrame = sweep["tidy"] + chosen: pd.DataFrame = sweep["chosen"] + + if tidy.empty: + print(f"[{axis_name}] no data, skipping") + return + + tidy.to_csv(out_dir / "tidy.csv", index=False) + chosen.to_csv(out_dir / "chosen_hparams_long.csv", index=False) + + distributions = sorted([d for d in tidy["distribution"].dropna().unique()]) + + # Per-distribution wide tables + plots. + for dist in distributions + [None]: + suffix = f"_{dist}" if dist else "_all" + lat_wide = _wide_latency(tidy, axis_name, distribution=dist) + spd_wide = _wide_speedup(tidy, axis_name, distribution=dist) + base = _wide_baseline(tidy, distribution=dist) + + lat_wide.to_csv(out_dir / f"table_latency_ms{suffix}.csv") + spd_wide.to_csv(out_dir / f"table_speedup_vs_baseline{suffix}.csv") + base.to_frame("baseline_ms").to_csv(out_dir / f"table_baseline_ms{suffix}.csv") + + with open(out_dir / f"table_latency_ms{suffix}.tex", "w") as f: + f.write(_df_to_latex(lat_wide, + caption=f"Best fused-kernel latency (ms) on {axis_name} sweep ({dist or 'all dists'})", + label=f"tab:lat-{axis_name}{suffix}")) + with open(out_dir / f"table_speedup_vs_baseline{suffix}.tex", "w") as f: + f.write(_df_to_latex(spd_wide, + caption=f"Speedup over unmapped baseline on {axis_name} sweep ({dist or 'all dists'})", + label=f"tab:spd-{axis_name}{suffix}")) + + _plot_metric_vs_axis( + tidy if dist is None else tidy[tidy["distribution"] == dist], + axis_name, "fused_ms", + out_dir / f"plot_latency_vs_{axis_name}{suffix}.pdf", + ylabel="fused TopK kernel latency (ms)", + title=f"TopK kernel latency vs {axis_name} ({dist or 'all dists'})", + baseline_series=base, + ) + # Speedup plot. + spd_long = tidy.copy() + if dist: + spd_long = spd_long[spd_long["distribution"] == dist] + spd_long = spd_long.assign( + speedup=spd_long["baseline_ms"] / spd_long["fused_ms"] + ) + _plot_metric_vs_axis( + spd_long, axis_name, "speedup", + out_dir / f"plot_speedup_vs_{axis_name}{suffix}.pdf", + ylabel="speedup over unmapped baseline", + title=f"Speedup vs {axis_name} ({dist or 'all dists'})", + ) + # Threshold bin size diagnostic. + _plot_metric_vs_axis( + tidy if dist is None else tidy[tidy["distribution"] == dist], + axis_name, "threshold_bin_size_mean", + out_dir / f"plot_threshold_bin_size_vs_{axis_name}{suffix}.pdf", + ylabel="mean threshold-bin size (entries)", + title=f"Stage-1 threshold bin size vs {axis_name} ({dist or 'all dists'})", + ) + + # Chosen-hparam wide table (axis-independent of distribution: autotune + # picks one hparam per mode per axis cell). + chosen_wide = _wide_chosen_hparam(chosen) + chosen_wide.to_csv(out_dir / "table_chosen_hparams.csv") + with open(out_dir / "table_chosen_hparams.tex", "w") as f: + f.write(_df_to_latex(chosen_wide, + caption=f"Autotuned remap-function hyperparameters per {axis_name} cell", + label=f"tab:hparam-{axis_name}")) + + # Markdown summary. + md_lines: List[str] = [] + md_lines.append(f"# Ablation: remap function vs `{axis_name}`\n") + md_lines.append(f"Source: `{sweep['index'].get('cells', [{}])[0].get('cell_dir', '')}/...`\n") + + md_lines.append("\n## Selected mapping functions (autotuned)\n") + md_lines.append("```") + for v in chosen_wide.index.tolist(): + parts = [] + for col in chosen_wide.columns: + label = chosen_wide.loc[v, col] + if isinstance(label, str) and label: + parts.append(label) + md_lines.append(f"[{axis_name}={v}] " + " ".join(parts)) + md_lines.append("```\n") + + md_lines.append("\n## Latency (ms) — best fused, all distributions\n") + md_lines.append(_wide_latency(tidy, axis_name).to_markdown()) + md_lines.append("\n\n## Speedup over unmapped baseline\n") + md_lines.append(_wide_speedup(tidy, axis_name).to_markdown()) + md_lines.append("\n\n## Chosen hyperparameters\n") + md_lines.append(chosen_wide.to_markdown()) + md_lines.append("\n\n## Plots\n") + for p in sorted(out_dir.glob("plot_*.pdf")): + md_lines.append(f"- `{p.name}`") + + with open(out_dir / "summary.md", "w") as f: + f.write("\n".join(md_lines) + "\n") + + print(f"[{axis_name}] wrote artifacts to {out_dir}") + + +# ---------- Top-level ---------- + +def main() -> None: + ap = argparse.ArgumentParser(description="Aggregate ablation_remap_function_*.sh sweep outputs.") + ap.add_argument("--sweep-dir", action="append", required=True, + help="A sweep directory containing sweep_index.json. Repeat for multiple sweeps.") + ap.add_argument("--output-dir", type=str, required=True, + help="Where to write tables, plots, and summary.") + args = ap.parse_args() + + out_root = Path(args.output_dir) + out_root.mkdir(parents=True, exist_ok=True) + + sweeps: List[Dict[str, Any]] = [] + for sd in args.sweep_dir: + sweep = load_sweep(Path(sd)) + emit_sweep(sweep, out_root) + sweeps.append(sweep) + + # Cross-axis recommended hparams: for every mode, pick the param value + # that was selected most often across all axis cells of all sweeps. + all_chosen = pd.concat([s["chosen"] for s in sweeps if not s["chosen"].empty], + ignore_index=True) if sweeps else pd.DataFrame() + rec_lines: List[str] = [] + if not all_chosen.empty: + rec = (all_chosen.groupby(["mode", "mode_name", "param_name"])["param_value"] + .agg(lambda s: s.value_counts().idxmax()) + .reset_index().rename(columns={"param_value": "recommended"})) + rec.to_csv(out_root / "recommended_hparams.csv", index=False) + rec_lines.append("## Cross-axis recommended hparams (mode of selections)\n") + rec_lines.append(rec.to_markdown(index=False)) + + index_lines = ["# Remap-function ablation summary\n"] + for s in sweeps: + axis = s["axis_name"] + index_lines.append(f"- [`{axis}`]({axis}/summary.md)") + if rec_lines: + index_lines.append("") + index_lines.extend(rec_lines) + with open(out_root / "index.md", "w") as f: + f.write("\n".join(index_lines) + "\n") + print(f"[index] {out_root}/index.md") + + +if __name__ == "__main__": + main() diff --git a/examples/profile_in_docker.sh b/examples/profile_in_docker.sh new file mode 100755 index 00000000..a64606ff --- /dev/null +++ b/examples/profile_in_docker.sh @@ -0,0 +1,181 @@ +#!/usr/bin/env bash +# ============================================================ +# Run examples/profile_parallel_vs_fused.sh inside an NVIDIA +# CUDA devel container so we can enable profiling without +# touching the host's RmProfilingAdminOnly=1 setting. +# +# Key idea: +# - The container has `ncu` bundled with the CUDA toolkit. +# - --cap-add=SYS_ADMIN gives the container the capability +# CUPTI needs to access perf counters, so ncu works +# regardless of the host's nvidia-driver profiling restriction. +# - We mount the host's uv venv and the project, so there's +# no Python/pytorch install inside the container — the host +# venv's python is used directly. +# +# Image: +# Defaults to an NGC public CUDA devel image. For B200 (Blackwell / +# sm_100) you need CUDA ≥ 12.8 and ncu ≥ 2024.3; CUDA 13.0+ covers +# that. Override with NCU_IMAGE if you prefer a specific tag. +# +# Usage: +# bash examples/profile_in_docker.sh # defaults +# GPU=2 NUM_SPLITS=2 bash examples/profile_in_docker.sh +# NCU_IMAGE=nvcr.io/nvidia/pytorch:25.03-py3 bash examples/profile_in_docker.sh +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" + +# Host venv to reuse. uv venvs have their own python binary under +# $VENV/bin/python3.x that is glibc/libstdc++-compatible with the +# container when using NGC Ubuntu 22.04 / 24.04 images. +VENV_DIR="${VENV_DIR:-/home/zhuominc/xinrui_projects/uv_env/vortex}" + +# NGC CUDA devel image on Ubuntu. Has /usr/local/cuda/bin/ncu bundled. +# 13.0.1-devel-ubuntu22.04 is public (no NGC login needed), supports +# B200, and matches the host's CUDA 13.x driver ABI. +# +# Alternatives: +# nvcr.io/nvidia/cuda:13.0.1-devel-ubuntu24.04 # newer base +# nvcr.io/nvidia/pytorch:25.03-py3 # if you don't want to +# # reuse the host venv +# Host is Ubuntu 24.04 + Python 3.12 (the uv venv points to /usr/bin/python3.12). +# Match the container to that so the venv's symlinked python resolves to a +# compatible interpreter inside the container. +NCU_IMAGE="${NCU_IMAGE:-nvcr.io/nvidia/cuda:13.0.1-devel-ubuntu24.04}" + +# Pass-through env vars for the inner profile script. Defaults match +# examples/profile_parallel_vs_fused.sh. +GPU="${GPU:-7}" +EFF_BS="${EFF_BS:-1}" +NUM_SPLITS="${NUM_SPLITS:-2}" +POWER="${POWER:--1.0}" +WARMUP="${WARMUP:-20}" +ITERS="${ITERS:-1}" +SECTION_SET="${SECTION_SET:-full}" + +# Inside the container, these mount points give the profile script the +# same absolute paths it sees on the host (so the script doesn't need +# to be container-aware). +MOUNT_ROOT="/home/zhuominc/xinrui_projects" + +if [ ! -d "${VENV_DIR}" ]; then + echo "ERROR: VENV_DIR not found: ${VENV_DIR}" + echo " Set VENV_DIR=/path/to/venv or install the venv." + exit 1 +fi + +VENV_PY="$(ls "${VENV_DIR}"/bin/python* 2>/dev/null | head -1 || true)" +if [ -z "${VENV_PY}" ]; then + echo "ERROR: no python found under ${VENV_DIR}/bin/" + exit 1 +fi + +echo "============================================================" +echo "Docker-wrapped ncu profiling" +echo " image: ${NCU_IMAGE}" +echo " venv: ${VENV_DIR} (python=${VENV_PY##*/})" +echo " project: ${PROJECT_DIR}" +echo " GPU: ${GPU}" +echo " eff_bs: ${EFF_BS}" +echo " num_splits: ${NUM_SPLITS}" +echo " power: ${POWER}" +echo " warmup/iters: ${WARMUP}/${ITERS}" +echo " section set: ${SECTION_SET}" +echo "============================================================" + +# Pull the image up-front (so the output during the run isn't +# interleaved with pull progress). `|| true` — pull is optional; +# if the image is already local, docker run will use the cached copy. +docker pull "${NCU_IMAGE}" || true + +# Run the profile script inside the container. +# +# --gpus all : give the container access to all GPUs +# (CUDA_VISIBLE_DEVICES inside the script +# narrows it down to GPU ${GPU}). +# --cap-add=SYS_ADMIN : lets CUPTI access perf counters without +# touching host profiling restrictions. +# --security-opt seccomp=unconfined : CUPTI needs a few syscalls +# the default seccomp profile blocks. +# --network host : not strictly required, but keeps pip/uv +# network access working if you ever add +# pip-install steps. +# --user $(id -u):$(id -g) +# : write output files owned by your user, +# not root. +# -v /etc/passwd:/etc/passwd:ro -v /etc/group:/etc/group:ro +# : so the uid inside resolves to a real +# user (helps some tools, harmless otherwise). +# -v ${MOUNT_ROOT}:${MOUNT_ROOT} +# : mount the whole xinrui_projects tree so +# both the project and the venv are visible +# at their host paths. +# -e PYTHONPATH=... : add the venv's site-packages explicitly +# so `python3 -c 'import vortex_torch_C'` +# resolves even without activate. +# -e PATH=... : put the venv's bin ahead of /usr/local/cuda/bin +# so `python` is the venv python, and keep ncu +# reachable. +# When invoked via `sudo`, `id -u` returns 0 (root). Prefer SUDO_UID/ +# SUDO_GID so the final chown hands results back to the real user, +# not root. Fall back to the effective uid/gid otherwise. +HOST_UID="${SUDO_UID:-$(id -u)}" +HOST_GID="${SUDO_GID:-$(id -g)}" + +docker run --rm \ + --gpus all \ + --cap-add=SYS_ADMIN \ + --security-opt seccomp=unconfined \ + --network host \ + --ipc=host \ + -e DISPLAY="${DISPLAY:-}" \ + -v /tmp/.X11-unix:/tmp/.X11-unix \ + -v "${MOUNT_ROOT}:${MOUNT_ROOT}" \ + -w "${PROJECT_DIR}" \ + -e GPU="${GPU}" \ + -e EFF_BS="${EFF_BS}" \ + -e NUM_SPLITS="${NUM_SPLITS}" \ + -e POWER="${POWER}" \ + -e WARMUP="${WARMUP}" \ + -e ITERS="${ITERS}" \ + -e SECTION_SET="${SECTION_SET}" \ + -e NCU="/usr/local/cuda/bin/ncu" \ + -e HOST_UID="${HOST_UID}" \ + -e HOST_GID="${HOST_GID}" \ + -e PATH="${VENV_DIR}/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" \ + -e LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}" \ + "${NCU_IMAGE}" \ + bash -lc ' + set -e + # Ubuntu 24.04 base may not ship python3.12 in the CUDA devel image. + # Install it idempotently; this is ~2s if missing and skipped otherwise. + if [ ! -x /usr/bin/python3.12 ]; then + echo "--- installing python3.12 in container ---" + export DEBIAN_FRONTEND=noninteractive + apt-get update -qq + apt-get install -y --no-install-recommends python3.12 >/dev/null + fi + echo "--- container environment ---" + echo "python: $(readlink -f "$(which python)") ($(python --version 2>&1))" + echo "ncu: $(which ncu)" + ncu --version 2>&1 | head -2 + nvidia-smi -L + python -c "import torch; print(\"torch: \", torch.__version__, \"cuda:\", torch.version.cuda)" + python -c "import vortex_torch_C; print(\"vortex_torch_C import OK\")" + echo "-----------------------------" + bash examples/profile_parallel_vs_fused.sh + # Hand output files back to the host user (we ran as root so apt + # could install python3.12). + chown -R "${HOST_UID}:${HOST_GID}" examples/results 2>/dev/null || true + ' + +echo "" +echo "============================================================" +echo "Docker profiling run complete." +echo "Reports are under: ${PROJECT_DIR}/examples/results/" +echo "(same path as the direct script — you own the files since we" +echo " ran the container as your uid)." +echo "============================================================" diff --git a/examples/profile_parallel_vs_fused_ncu.sh b/examples/profile_parallel_vs_fused_ncu.sh new file mode 100755 index 00000000..bf9baaf1 --- /dev/null +++ b/examples/profile_parallel_vs_fused_ncu.sh @@ -0,0 +1,277 @@ +#!/usr/bin/env bash +# ============================================================ +# Nsight Compute profiling script for the parallel vs fused +# TopK kernels. +# +# Profiles both: +# - TopKOutput_Fused_Kernel (csrc/topk_sglang.cu) +# - TopKOutput_Parallel_Kernel (csrc/topk_sglang_parallel.cu) +# +# With both remap functions the user cares about: +# - mode 15: MAPPING_SHIFT_POW2 +# - mode 16: MAPPING_SHIFT_POW3 +# +# And both configs: +# - A: topk=2048, pages_per_seg=32K (topk=2k from 32k) +# - B: topk=30, pages_per_seg=2K (topk=30 from 2k) +# +# Produces one .ncu-rep per (kernel × mode × config). Open with +# the Nsight Compute GUI for an interactive comparison, or dump on +# the CLI with `ncu --import .ncu-rep --page details`. +# +# Usage: +# bash examples/profile_parallel_vs_fused.sh # defaults +# GPU=4 EFF_BS=1 bash examples/profile_parallel_vs_fused.sh # small-batch case +# GPU=4 EFF_BS=32 bash examples/profile_parallel_vs_fused.sh # saturated case +# GPU=4 NUM_SPLITS=2 bash examples/profile_parallel_vs_fused.sh +# +# Requires `ncu` on PATH (part of the CUDA toolkit). On most systems +# accessing performance counters requires either: +# - root/sudo, or +# - `echo 1 | sudo tee /proc/driver/nvidia/params` (temporary), or +# - setting NVreg_RestrictProfilingToAdminUsers=0 in the nvidia driver. +# If ncu reports "ERR_NVGPUCTRPERM" you'll need one of the above. +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PY_DRIVER="${SCRIPT_DIR}/../benchmarks/profile_parallel_vs_fused.py" + +# ── Defaults ────────────────────────────────────────────────── +GPU=${GPU:-7} +EFF_BS=${EFF_BS:-1} # eff_batch_size = batch_size * num_kv_heads +NUM_SPLITS=${NUM_SPLITS:-2} # only used by the parallel kernel +POWER=${POWER:--1.0} # pivot p for shift_pow{2,3} +WARMUP=${WARMUP:-20} # matching-kernel warmup launches (ncu skips) +ITERS=${ITERS:-1} # matching-kernel profiled launches (ncu captures) +SECTION_SET=${SECTION_SET:-full} # ncu section set: "full", "basic", or named sections + +# Profiling robustness knobs for shared GPUs / CUDA 13 systems. +# --replay-mode application: re-run the entire process to collect each +# counter pass, instead of replaying individual +# kernels. Fixes "Failed to prepare kernel" on +# systems where kernel replay hits PMU conflicts. +# --clock-control none : don't try to lock GPU clocks (requires admin on +# shared GPUs; without this, "Unknown error on +# device 0" is common). +# --cache-control none : don't flush L1/L2 between passes (also needs +# admin on shared systems). +# Override with NCU_EXTRA_FLAGS="..." if you need a different combination. +NCU_EXTRA_FLAGS=${NCU_EXTRA_FLAGS:-"--replay-mode application --clock-control none --cache-control none"} + +# DIAG=1 bash profile_parallel_vs_fused.sh → run one tiny ncu probe to +# verify profiling works before doing the full sweep. +DIAG=${DIAG:-0} + +# ── ncu command ─────────────────────────────────────────────── +NCU=${NCU:-ncu} +command -v "${NCU}" >/dev/null 2>&1 || { + echo "ERROR: '${NCU}' not found on PATH. Install Nsight Compute (part of CUDA Toolkit)" + echo " or set NCU=/path/to/ncu and re-run." + exit 1 +} + +# The templated kernels end up with mangled names like +# _Z25TopKOutput_Fused_KernelI13__nv_bfloat16ILi15EEEvPKT_... +# ncu supports --kernel-name regex: which matches on the +# demangled signature. Using "TopKOutput_Fused_Kernel" and +# "TopKOutput_Parallel_Kernel" as the regex selects all template +# instantiations of each kernel but nothing else. +FUSED_REGEX="regex:TopKOutput_Fused_Kernel" +PARALLEL_REGEX="regex:TopKOutput_Parallel_Kernel" + +# ── Output dir ──────────────────────────────────────────────── +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_DIR="${SCRIPT_DIR}/results/ncu_parallel_vs_fused_${TIMESTAMP}" +mkdir -p "${OUT_DIR}" + +echo "============================================================" +echo "Nsight Compute profile: parallel vs fused TopK" +echo " GPU: ${GPU}" +echo " eff_bs: ${EFF_BS}" +echo " num_splits: ${NUM_SPLITS} (parallel kernel only)" +echo " power (p): ${POWER} (for shift_pow{2,3})" +echo " warmup: ${WARMUP} (matching-kernel launches skipped by ncu)" +echo " iters: ${ITERS} (matching-kernel launches captured)" +echo " sections: --set ${SECTION_SET}" +echo " extra ncu flags:${NCU_EXTRA_FLAGS}" +echo " output dir: ${OUT_DIR}" +echo "============================================================" + +# ── Diagnostic probe ───────────────────────────────────────── +# Verifies that ncu can attach and collect at least one section on +# this GPU before we burn time on the full sweep. Uses --set basic +# which is the cheapest section set. If this fails, see the +# TROUBLESHOOTING block that the script prints on error. +run_diag() { + echo "" + echo ">>> Diagnostic probe: can ncu attach at all?" + local out="${OUT_DIR}/diag.ncu-rep" + set +e + CUDA_VISIBLE_DEVICES="${GPU}" "${NCU}" \ + --force-overwrite \ + --target-processes all \ + --kernel-name "${FUSED_REGEX}" \ + --launch-skip "${WARMUP}" \ + --launch-count 1 \ + --set basic \ + ${NCU_EXTRA_FLAGS} \ + --export "${out}" \ + python "${PY_DRIVER}" \ + --config A --eff-bs 1 --mode 15 --power "${POWER}" \ + --num-splits "${NUM_SPLITS}" --kernel fused \ + --warmup "${WARMUP}" --iters 1 + local rc=$? + set -e + if [ ${rc} -ne 0 ]; then + cat <<'EOF' + +============================================================ +TROUBLESHOOTING "Failed to prepare kernel for profiling" +============================================================ + 1) Is another process using GPU ${GPU}? Check: + nvidia-smi + If yes, pick an idle GPU: + GPU=0 bash examples/profile_parallel_vs_fused.sh + + 2) Perf counters may be locked to admin. Try as root: + sudo -E bash examples/profile_parallel_vs_fused.sh + + Or permanently unlock (admin, persists until reboot): + sudo sh -c 'echo 1 > /proc/driver/nvidia/params' + + Or permanently in the driver (needs reboot): + Add NVreg_RestrictProfilingToAdminUsers=0 to + /etc/modprobe.d/nvidia.conf + + 3) MPS or another profiler (CUPTI, Nsight Systems, etc.) + may be running. Kill with: + echo quit | nvidia-cuda-mps-control + and verify nothing else is profiling. + + 4) On H100 with MIG: profiling across MIG slices is + restricted. Use a full-device GPU. + + 5) Try a smaller ncu configuration first: + NCU_EXTRA_FLAGS="--replay-mode application --clock-control none --cache-control none --metrics sm__cycles_elapsed.avg" \ + bash examples/profile_parallel_vs_fused.sh + + 6) CUDA 13.2 vs PyTorch-13.0 mismatch is sometimes flagged + by ncu. Update ncu to match CUDA 13.2, or use the ncu + shipped with CUDA 13.2: + NCU=/usr/local/cuda-13.2/bin/ncu bash ... + +============================================================ +EOF + echo "Diagnostic probe failed (exit ${rc}). See troubleshooting above." + exit ${rc} + fi + echo ">>> Diagnostic probe OK. Proceeding with full sweep." +} + +if [ "${DIAG}" = "1" ]; then + run_diag + exit 0 +fi + +# Always run a cheap probe first so full-sweep failures are caught early +# before we've spent minutes on the heavy --set full passes. +run_diag + +# ── Helper: run one ncu profile ────────────────────────────── +# tag : name used for the output file +# kernel : "fused" or "parallel" (drives Python driver dispatch) +# regex : ncu --kernel-name filter +# config : "A" or "B" +# mode : 15 or 16 +run_ncu() { + local tag="$1" + local kernel="$2" + local regex="$3" + local config="$4" + local mode="$5" + + local out="${ + + + + }/${tag}.ncu-rep" + + echo "" + echo ">>> ${tag}" + + # --launch-skip/--launch-count count ONLY kernels matching + # --kernel-name, so setup kernels (torch.randn, etc.) don't + # pollute the offsets. With --launch-skip=${WARMUP} and the + # Python driver doing ${WARMUP} warmup + ${ITERS} profiled + # calls, ncu captures exactly the profiled ones. + CUDA_VISIBLE_DEVICES="${GPU}" "${NCU}" \ + --force-overwrite \ + --target-processes all \ + --kernel-name "${regex}" \ + --launch-skip "${WARMUP}" \ + --launch-count "${ITERS}" \ + --set "${SECTION_SET}" \ + ${NCU_EXTRA_FLAGS} \ + --export "${out}" \ + python "${PY_DRIVER}" \ + --config "${config}" \ + --eff-bs "${EFF_BS}" \ + --mode "${mode}" \ + --power "${POWER}" \ + --num-splits "${NUM_SPLITS}" \ + --kernel "${kernel}" \ + --warmup "${WARMUP}" \ + --iters "${ITERS}" + + echo " report: ${out}" +} + +# ── Sweep ──────────────────────────────────────────────────── +for MODE in 15 16; do + if [ "${MODE}" -eq 15 ]; then MODE_TAG="SP2"; else MODE_TAG="SP3"; fi + for CONFIG in A B; do + run_ncu "fused_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}" \ + "fused" "${FUSED_REGEX}" "${CONFIG}" "${MODE}" + run_ncu "parallel_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}_ns${NUM_SPLITS}" \ + "parallel" "${PARALLEL_REGEX}" "${CONFIG}" "${MODE}" + done +done + +echo "" +echo "============================================================" +echo "All profiles done. Reports saved under:" +echo " ${OUT_DIR}" +echo "" +echo "Interactive analysis (recommended):" +echo " ncu-ui ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.ncu-rep" +echo "" +echo "CLI summary, one kernel at a time:" +echo " ncu --import ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.ncu-rep --page details" +echo "" +echo "Side-by-side diff (CLI):" +echo " ncu --import ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.ncu-rep \\" +echo " --import ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.ncu-rep \\" +echo " --page details --csv > ${OUT_DIR}/compare_SP2_cfgA.csv" +echo "" +echo "What to look at (to pinpoint the overhead vs fused):" +echo " * Section 'GPU Speed Of Light Throughput'" +echo " → SM %, Memory %, which one is the bound?" +echo " * Section 'Launch Statistics'" +echo " → Grid/Block size, Dynamic Shared Mem per block" +echo " * Section 'Occupancy'" +echo " → Theoretical vs achieved; limit (smem / regs / blocks/SM)" +echo " * Section 'Warp State Statistics'" +echo " → Stall breakdown: Stall Barrier (__syncthreads/__threadfence)," +echo " Stall Long Scoreboard (global memory), Stall Short Scoreboard" +echo " (smem/atomic)" +echo " * Section 'Memory Workload Analysis'" +echo " → L2/Device throughput, atomic traffic, smem bank conflicts" +echo " * Section 'Compute Workload Analysis'" +echo " → Pipe utilisation (FMA / ALU / FP64)" +echo "" +echo "Likely suspects for the parallel-vs-fused gap:" +echo " - Occupancy limited by the large dynamic smem (kSmem + chunk_bytes)" +echo " - Stall Barrier dominating due to the __threadfence before atomicInc" +echo " - Phase 1 CTAs repeat Stage-2 refinement that fused does only once" +echo " → visible as 'Pipe Utilisation ALU / Special' for integer radix ops" +echo "============================================================" diff --git a/examples/profile_parallel_vs_fused_nsys.sh b/examples/profile_parallel_vs_fused_nsys.sh new file mode 100755 index 00000000..3d64519a --- /dev/null +++ b/examples/profile_parallel_vs_fused_nsys.sh @@ -0,0 +1,211 @@ +#!/usr/bin/env bash +# ============================================================ +# Nsight Systems (nsys) profiling — timeline view of the parallel +# vs fused TopK kernels. +# +# Why nsys and not ncu here: +# ncu needs SM-level perf counters (sm__*), which on this box are +# gated by the nvidia driver's RmProfilingAdminOnly flag — and we +# have no sudo. nsys uses CUPTI API/activity tracing and kernel +# timing, which do NOT require admin. That's enough to answer the +# "where does the 6-8us overhead come from" question, because we +# get per-kernel durations, gaps on the stream, memcpy/memset +# traffic, and NVTX range timing. +# +# Profiles both: +# - TopKOutput_Fused_Kernel (csrc/topk_sglang.cu) +# - TopKOutput_Parallel_Kernel (csrc/topk_sglang_parallel.cu) +# +# For each of mode 15 (SHIFT_POW2), mode 16 (SHIFT_POW3) and both +# configs A (topk=2048 pages=32K) and B (topk=30 pages=2K). +# +# Produces one .nsys-rep per (kernel × mode × config). Open with: +# nsys-ui .nsys-rep +# or dump CLI summaries with: +# nsys stats .nsys-rep +# +# Usage: +# bash examples/profile_parallel_vs_fused_nsys.sh # defaults +# GPU=7 NUM_SPLITS=2 bash examples/profile_parallel_vs_fused_nsys.sh +# ITERS=50 bash examples/profile_parallel_vs_fused_nsys.sh # more samples +# ============================================================ +set -euo pipefail +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PY_DRIVER="${SCRIPT_DIR}/../benchmarks/profile_parallel_vs_fused.py" + +# ── Defaults ────────────────────────────────────────────────── +GPU=${GPU:-7} +EFF_BS=${EFF_BS:-1} +NUM_SPLITS=${NUM_SPLITS:-2} +POWER=${POWER:--1.0} +WARMUP=${WARMUP:-20} +# For nsys we want *many* iterations so the per-kernel timing is +# statistically meaningful and the timeline is readable. +ITERS=${ITERS:-50} + +# Prefer the CUDA-13 toolchain's nsys (matches the torch CUDA ABI). +NSYS=${NSYS:-$(command -v nsys || echo /usr/local/cuda/bin/nsys)} +if [ ! -x "${NSYS}" ]; then + echo "ERROR: nsys not found. Tried: ${NSYS}" + echo " Set NSYS=/path/to/nsys manually." + exit 1 +fi + +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +OUT_DIR="${SCRIPT_DIR}/results/nsys_parallel_vs_fused_${TIMESTAMP}" +mkdir -p "${OUT_DIR}" + +# nsys writes intermediate files under $TMPDIR/nvidia/nsight_systems. +# On shared systems /tmp/nvidia is often owned by another user who +# created it first, and we can't write there. Redirect to a +# user-writable cache dir. +export TMPDIR="${TMPDIR:-${HOME}/.cache/nsys_tmp}" +mkdir -p "${TMPDIR}" + +echo "============================================================" +echo "Nsight Systems profile: parallel vs fused TopK" +echo " GPU: ${GPU}" +echo " eff_bs: ${EFF_BS}" +echo " num_splits: ${NUM_SPLITS}" +echo " power (p): ${POWER}" +echo " warmup: ${WARMUP}" +echo " iters: ${ITERS} (profiled launches)" +echo " nsys binary: ${NSYS}" +echo " output dir: ${OUT_DIR}" +echo "============================================================" +"${NSYS}" --version 2>&1 | head -2 + +# ── Helper: run one nsys profile ───────────────────────────── +run_nsys() { + local tag="$1" + local kernel="$2" + local config="$3" + local mode="$4" + + local out="${OUT_DIR}/${tag}" + + echo "" + echo ">>> ${tag}" + + # --trace cuda,nvtx : CUDA API/runtime + NVTX ranges. NVTX + # stays on so the timeline still shows + # where the profiled region begins. + # --sample none / --cpuctxsw none: skip CPU callstack sampling and + # context-switch tracing — both admin-gated + # on this box and we don't need them. + # --cuda-memory-usage true: log cudaMalloc/cudaFree/cudaMemset so we + # can see if at::empty / at::zeros costs + # anything on the hot path. + # + # Capture-range flags intentionally OMITTED. On some nsys builds + # --capture-range=nvtx silently yields "No reports were generated" + # when the ranges don't line up exactly; profiling the whole run + # is more robust and the warmup is easy to filter out later + # (NVTX range "profile-*" tags the profiled region in nsys stats). + CUDA_VISIBLE_DEVICES="${GPU}" "${NSYS}" profile \ + --output "${out}" \ + --force-overwrite true \ + --trace cuda,nvtx \ + --sample none \ + --cpuctxsw none \ + --cuda-memory-usage true \ + python "${PY_DRIVER}" \ + --config "${config}" \ + --eff-bs "${EFF_BS}" \ + --mode "${mode}" \ + --power "${POWER}" \ + --num-splits "${NUM_SPLITS}" \ + --kernel "${kernel}" \ + --warmup "${WARMUP}" \ + --iters "${ITERS}" + + echo " report: ${out}.nsys-rep" +} + +# ── Sweep ──────────────────────────────────────────────────── +for MODE in 15 16; do + if [ "${MODE}" -eq 15 ]; then MODE_TAG="SP2"; else MODE_TAG="SP3"; fi + for CONFIG in A B; do + run_nsys "fused_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}" \ + "fused" "${CONFIG}" "${MODE}" + run_nsys "parallel_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}_ns${NUM_SPLITS}" \ + "parallel" "${CONFIG}" "${MODE}" + done +done + +# ── Auto-dump CLI summaries for every report ───────────────── +# `nsys stats` produces text tables that are immediately readable +# and answer most "where did the time go" questions without needing +# the GUI. We dump the most useful ones for every report and stash +# them alongside. +echo "" +echo "============================================================" +echo "Dumping text summaries ('nsys stats') for every report..." +echo "============================================================" +for rep in "${OUT_DIR}"/*.nsys-rep; do + name="$(basename "${rep}" .nsys-rep)" + echo "" + echo ">>> summary for ${name}" + summary="${OUT_DIR}/${name}.summary.txt" + { + echo "### ${name}" + echo "" + echo "## cuda_api_sum: CUDA runtime API call distribution" + echo "## (count, avg, med, min, max of cudaLaunchKernel / cudaMalloc / etc.)" + "${NSYS}" stats --report cuda_api_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_gpu_kern_sum: per-kernel GPU duration stats" + echo "## (mean/median/std/min/max duration per kernel name, with instance count)" + "${NSYS}" stats --report cuda_gpu_kern_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_gpu_mem_size_sum: memcpy / memset by size" + echo "## (expect 0 memset entries for parallel — no at::zeros on the hot path)" + "${NSYS}" stats --report cuda_gpu_mem_size_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_gpu_mem_time_sum: memcpy / memset by time" + "${NSYS}" stats --report cuda_gpu_mem_time_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## cuda_kern_exec_sum: kernel launch→exec latency" + echo "## (host-side cudaLaunchKernel cost separated from GPU exec cost)" + "${NSYS}" stats --report cuda_kern_exec_sum --format table "${rep}" 2>&1 || true + echo "" + echo "## nvtx_pushpop_sum: NVTX ranges (the 'profile-*' wrapped region)" + "${NSYS}" stats --report nvtx_pushpop_sum --format table "${rep}" 2>&1 || true + } > "${summary}" 2>&1 + echo " saved: ${summary}" +done + +echo "" +echo "============================================================" +echo "Reports saved to: ${OUT_DIR}" +echo "" +echo "Quick read — compare fused vs parallel summaries side-by-side:" +echo "" +echo " diff -y --width=200 \\" +echo " ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.summary.txt \\" +echo " ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.summary.txt \\" +echo " | less" +echo "" +echo "Interactive timeline (if you have X11/SSH forwarding):" +echo " nsys-ui ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.nsys-rep" +echo "" +echo "What to look for (to nail the overhead vs fused):" +echo " * 'cuda_gpu_kern_sum' mean duration for each kernel" +echo " → fused is one kernel × (WARMUP+ITERS), parallel is one kernel × (WARMUP+ITERS)" +echo " (single-kernel design). Mean duration difference = the GPU work" +echo " gap (Stage-1 savings minus merge cost)." +echo " * 'cuda_api_sum' cudaLaunchKernel / cudaMalloc / cudaFree counts" +echo " → if parallel shows more launches than fused, there's an unexpected" +echo " extra kernel. Also watch the time spent in cudaLaunchKernel." +echo " * 'cuda_gpu_mem_size_sum' cudaMemset entries" +echo " → should be zero for parallel now (__device__ counter removed" +echo " at::zeros). Any memset here IS overhead we need to explain." +echo " * 'cuda_kern_exec_sum'" +echo " → separates host-side cudaLaunchKernel latency from GPU kernel time." +echo " * 'nvtx_pushpop_sum' profile-* range duration / ${ITERS}" +echo " → wall-clock per-call including CPU-side overhead." +echo "" +echo "Timeline view (nsys-ui) additionally shows *gaps* between kernels" +echo "on the GPU stream — the cost of __threadfence + atomicInc barrier" +echo "shows up as a visible pause between Phase-1 work and the merge." +echo "============================================================" diff --git a/examples/remap_function_bench_topk_parallel.sh b/examples/remap_function_bench_topk_parallel.sh new file mode 100755 index 00000000..df33f2cf --- /dev/null +++ b/examples/remap_function_bench_topk_parallel.sh @@ -0,0 +1,245 @@ +#!/usr/bin/env bash +# ============================================================ +# Remap Function Benchmark — Parallel TopK variant. +# +# Wraps bench_topk.py --remap-bench with --bench-parallel so the +# output table includes a "par_ms" column comparing the split+merge +# kernel (topk_output_sglang_parallel) against the single-CTA +# fused kernel. Also sweeps batch size and num_splits so the +# occupancy-vs-merge-overhead curve is visible. +# +# Pipeline mirrors remap_function_bench_topk2028.sh: +# Step 1 — calibrate (can be skipped with --real-histograms) +# Step 2 — autotune per-mode hparams by fused-kernel latency +# Step 3 — remap bench, looped over NUM_SPLITS_SWEEP values +# +# Usage: +# bash remap_function_bench_topk_parallel.sh --gpu 4 +# +# # Explicit batch-size sweep: +# bash remap_function_bench_topk_parallel.sh --gpu 4 \ +# --batch-sizes "1 2 4 8" --num-splits-sweep "auto 2 4 8" +# ============================================================ +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +BENCH_DIR="${SCRIPT_DIR}/../benchmarks" + +# ── Defaults ────────────────────────────────────────────────── +GPU_ID=4 +MODEL_NAME="Qwen/Qwen3-1.7B" +TOPK_VAL=2048 +MEM=0.7 +MAX_TOTAL_TOKENS=64768 +MIN_FREE_DISK_GB=22 +ALGO="block_sparse_attention" +SAMPLE_STRIDE=1 +SEQ_LEN=32768 +BLOCK_SIZE=1 +BATCH_SIZES="1 2 4 8 16" +NUM_KV_HEADS=8 +DISTRIBUTIONS="normal bucket_uniform" +# Modes excluding 1 (LUT_CDF) and 2 (Quantile) which are discarded. +MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19" +MAPPING_HPARAM=0.5 +REPEAT=100 +WARMUP=20 +# "auto" lets bench_topk.py pick via sqrt(pages/topk). Explicit ints +# pin a split count for A/B comparisons. +NUM_SPLITS_SWEEP="auto 2 4 8" +REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy" +SKIP_AUTOTUNE=0 +PINNED_AUTOTUNE_JSON="" + +# ── Parse arguments ─────────────────────────────────────────── +while [[ $# -gt 0 ]]; do + case "$1" in + --model-name) MODEL_NAME="$2"; shift 2 ;; + --topk-val) TOPK_VAL="$2"; shift 2 ;; + --mem) MEM="$2"; shift 2 ;; + --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; + --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --algo) ALGO="$2"; shift 2 ;; + --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; + --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; + --seq-len) SEQ_LEN="$2"; shift 2 ;; + --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; + --batch-sizes) BATCH_SIZES="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --modes) MAPPING_MODES="$2"; shift 2 ;; + --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --num-splits-sweep) NUM_SPLITS_SWEEP="$2"; shift 2 ;; + --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; + --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; + *) echo "Unknown option: $1"; exit 1 ;; + esac +done + +export CUDA_VISIBLE_DEVICES="${GPU_ID}" +export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" +export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" + +if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then + if [ -x /usr/local/cuda/bin/nvcc ]; then + export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" + export PATH="${CUDA_HOME}/bin:${PATH}" + export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" + elif command -v nvcc >/dev/null 2>&1; then + export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" + fi +fi + +MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) +if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then + echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." + echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" + exit 1 +fi + +RESULTS_DIR="${SCRIPT_DIR}/results" +mkdir -p "${RESULTS_DIR}" +TIMESTAMP=$(date +%Y%m%d_%H%M%S) +MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" +RUN_DIR="${RESULTS_DIR}/parallel_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +mkdir -p "${RUN_DIR}" + +CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" +MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" +DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" +mkdir -p "${CALIBRATION_BASE}" + +if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then + REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" +fi + +echo "============================================================" +echo "Remap Function Benchmark (Parallel TopK variant)" +echo " Model: ${MODEL_NAME}" +echo " Algorithm: ${ALGO}" +echo " TopK: ${TOPK_VAL}" +echo " Block size: ${BLOCK_SIZE}" +echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" +echo " Batch sizes: ${BATCH_SIZES}" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " Mapping modes: ${MAPPING_MODES}" +echo " num_splits sweep:${NUM_SPLITS_SWEEP}" +echo " GPU: ${GPU_ID}" +echo " Real histograms: ${REAL_HISTOGRAMS:-}" +echo " Output: ${RUN_DIR}" +echo "============================================================" + +# ── Step 1: Calibrate ──────────────────────────────────────── +if [ -n "${REAL_HISTOGRAMS}" ]; then + echo "" + echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" + REAL_HIST_PATH="${REAL_HISTOGRAMS}" +else + echo "" + echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" + CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" + mkdir -p "${CALIBRATION_DIR}" + python "${BENCH_DIR}/calibrate_topk.py" \ + --model-name "${MODEL_NAME}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --mem "${MEM}" \ + --max-total-tokens "${MAX_TOTAL_TOKENS}" \ + --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ + --vortex-module-name "${ALGO}" \ + --output-dir "${CALIBRATION_DIR}" \ + 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" + mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" + REAL_HIST_PATH="${DEFAULT_REAL_HIST}" + echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" +fi + +# ── Step 2: Autotune ───────────────────────────────────────── +AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" +if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then + echo "" + if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then + echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" + AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" + else + echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" + AUTOTUNE_ARGS="" + fi +else + echo "" + echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" + # Autotune on the largest batch size so the picked hparam matches realistic + # decode conditions; the hparam itself is largely batch-invariant. + FIRST_BS="$(echo ${BATCH_SIZES} | awk '{print $NF}')" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ + --batch-size "${FIRST_BS}" \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-len "${SEQ_LEN}" \ + --topk-val "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --real-histograms "${REAL_HIST_PATH}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --collect-stats \ + --output-json "${AUTOTUNE_JSON}" \ + 2>&1 | tee "${RUN_DIR}/step2_autotune.log" + echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" + AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" +fi + +# ── Step 3: Remap + Parallel bench, sweeping num_splits ────── +echo "" +echo ">>> Step 3: Timing baseline / fused / parallel with num_splits sweep" + +for NS in ${NUM_SPLITS_SWEEP}; do + if [ "${NS}" = "auto" ]; then + NS_ARG="--num-splits -1" + NS_TAG="auto" + else + NS_ARG="--num-splits ${NS}" + NS_TAG="ns${NS}" + fi + REMAP_JSON="${RUN_DIR}/remap_bench_${NS_TAG}.json" + LOG="${RUN_DIR}/step3_remap_bench_${NS_TAG}.log" + BENCH_EXTRA=() + [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") + echo "" + echo "--- num_splits=${NS_TAG} ---" + PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --bench-parallel \ + ${NS_ARG} \ + --batch-sizes ${BATCH_SIZES} \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens "${SEQ_LEN}" \ + --topk-vals "${TOPK_VAL}" \ + --page-size "${BLOCK_SIZE}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes ${MAPPING_MODES} \ + --mapping-hparam "${MAPPING_HPARAM}" \ + ${AUTOTUNE_ARGS} \ + "${BENCH_EXTRA[@]}" \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${REMAP_JSON}" \ + 2>&1 | tee "${LOG}" + echo ">>> num_splits=${NS_TAG}: JSON -> ${REMAP_JSON}" +done + +# ── Summary ─────────────────────────────────────────────────── +echo "" +echo "============================================================" +echo "Parallel TopK Benchmark Complete" +echo " Model: ${MODEL_NAME}" +echo " Block size: ${BLOCK_SIZE}" +echo " Batch sizes: ${BATCH_SIZES}" +echo " num_splits sweep: ${NUM_SPLITS_SWEEP}" +echo " All outputs in: ${RUN_DIR}/" +echo " autotune_results.json — latency-ranked mapping hparams" +echo " remap_bench_.json — per-config latencies including par_ms" +echo " step{1,2,3}_*.log — pipeline logs" +echo "============================================================" diff --git a/examples/verify_algo.py b/examples/verify_algo.py index a78f1e69..dacba655 100644 --- a/examples/verify_algo.py +++ b/examples/verify_algo.py @@ -303,10 +303,12 @@ def parse_args(): "--topk-mapping-mode", type=int, default=0, - choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13], + choices=[0, 1, 2, 3, 4, 6, 7, 8, 9, 10, 11, 13, 15, 16, 17, 18, 19, 20], help='TopK mapping mode for sglang_fused: 0=none, 1=lut_cdf (calibrated), ' '2=quantile (calibrated), 3=power, 4=log, 6=asinh, 7=log1p, 8=trunc8, ' - '9=erf, 10=tanh, 11=subtract, 13=exp_stretch (default: 0).', + '9=erf, 10=tanh, 11=subtract, 13=exp_stretch, 15=shift_pow2, ' + '16=shift_pow3, 17=linear_steep, 18=half_square, 19=half_cube, ' + '20=dense_mant (default: 0).', ) parser.add_argument( @@ -326,11 +328,20 @@ def parse_args(): "Use multiple values to run several benchmarks sequentially (default: amc23).", ) + parser.add_argument( + "--output-json", + type=str, + default=None, + help="Optional path. When set, a JSON list of per-benchmark summary dicts is " + "dumped here after all benchmarks finish. Used by the ablation wrappers.", + ) + return parser.parse_args() if __name__ == "__main__": args = parse_args() + all_summaries = [] for bench_name in args.benchmark: if bench_name not in BENCHMARK_REGISTRY: print(f"WARNING: Unknown benchmark '{bench_name}', skipping. Available: {list(BENCHMARK_REGISTRY.keys())}") @@ -353,6 +364,20 @@ def parse_args(): benchmark=bench_name, ) summary["benchmark"] = bench_name + summary["model_name"] = args.model_name + summary["topk_val"] = args.topk_val + summary["page_size"] = args.page_size + summary["topk_type"] = args.topk_type + summary["topk_mapping_mode"] = args.topk_mapping_mode + summary["topk_mapping_hparam"] = args.topk_mapping_hparam + summary["full_attention"] = bool(args.full_attention) print(summary) + all_summaries.append(summary) + + if args.output_json: + os.makedirs(os.path.dirname(os.path.abspath(args.output_json)) or ".", exist_ok=True) + with open(args.output_json, "w") as f: + json.dump(all_summaries, f, indent=2) + print(f"\n[verify_algo] summary JSON written to {args.output_json}") exit(0) \ No newline at end of file diff --git a/setup.py b/setup.py index c9731815..8b496610 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ 'csrc/topk_sglang.cu', 'csrc/topk_sglang_profile.cu', 'csrc/topk_sglang_ori.cu', + 'csrc/topk_sglang_parallel.cu', ], include_dirs=['csrc'], extra_compile_args={ From 28ccd8578cfa1a661db763d0c4e8c9c7758516e0 Mon Sep 17 00:00:00 2001 From: UED Date: Mon, 20 Apr 2026 14:55:44 -0400 Subject: [PATCH 23/24] Add new TopK cluster and fast merge kernels, update profiling and benchmarking - Introduced and for improved TopK performance on Hopper architecture. - Updated to include the new source files for the cluster and fast merge kernels. - Enhanced to support the new kernels and added logic for automatic split determination. - Retired outdated mapping modes (LUT_CDF, QUANTILE, DENSE_MANT) from the kernel implementations and profiling scripts. - Modified example scripts to reflect changes in benchmarking configurations and GPU utilization improvements. --- benchmarks/bench_topk.py | 120 +- csrc/register.cc | 21 +- csrc/register.h | 39 +- csrc/topk_mapping.cuh | 47 +- csrc/topk_sglang.cu | 4 +- csrc/topk_sglang_cluster.cu | 684 ++++++++++ csrc/topk_sglang_parallel.cu | 1184 +++++++---------- csrc/topk_sglang_profile.cu | 25 +- .../remap_function_bench_topk_parallel.sh | 2 +- setup.py | 1 + 10 files changed, 1349 insertions(+), 778 deletions(-) create mode 100644 csrc/topk_sglang_cluster.cu diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index a717c7b5..8198d0c2 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -27,7 +27,8 @@ topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) topk_output_sglang_fused, # fused remap + 2-stage radix topk topk_output_sglang_ori, # original SGLang reference kernel - topk_output_sglang_parallel, # multi-CTA split+merge variant of the fused kernel + fast_fused_topk_merge, # single-kernel split+merge (new parallel kernel) + fast_cluster_topk_merge, # Hopper TBC+DSMEM fused split+merge (sm_90+) topk_remap_only, # standalone value-space remap topk_profile_histogram, topk_profile_counters, @@ -99,8 +100,16 @@ def _auto_num_splits(eff_batch_size: int, pages_per_seg: int, topk_val: int) -> except Exception: sm = 132 balanced = max(1, int(round((pages_per_seg / max(1, topk_val)) ** 0.5))) + # SM-budget floor is 1, but 1 means "don't split" — pointless for the + # parallel kernel and would ask Phase-1 to cache the entire seq in + # shared memory (blows past the 96 KB ceiling). Clamp to at least 2 + # whenever max_safe allows it; the caller will skip parallel entirely + # if it really doesn't want to split. sm_budget = max(1, sm // max(1, eff_batch_size)) - return max(1, min(balanced, sm_budget, max_safe)) + choice = min(balanced, sm_budget, max_safe) + if choice < 2 and max_safe >= 2: + choice = 2 + return max(1, choice) def _load_autotune_hparams(path: str) -> Dict[int, float]: @@ -628,32 +637,101 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) - # Multi-CTA split+merge variant of the fused kernel. num_splits <= 1 - # delegates to the single-CTA fused path, so this is only a - # meaningful extra data point when we can actually split. + # New single-kernel split+merge (fast_fused_topk_merge). Takes a + # dense [B, N, chunk] score tensor and writes [B, K] int32 + # indices. We reshape the bench's [eff_bs * pages_per_seg] flat + # scores into [eff_bs, num_splits, chunk_per_split] and compare + # the resulting batch-local indices against the fused kernel's + # (which map through the identity dense_kv_indices). parallel_ms = None parallel_splits_used = None - if getattr(args, "bench_parallel", False): + cluster_ms = None + if getattr(args, "bench_parallel", False) and mode in { + 3, 6, 7, 9, 10, 11, 13, 15, 16, 17 + }: splits = getattr(args, "num_splits", -1) if splits is None or splits < 1: splits = _auto_num_splits(eff_bs, pages_per_seg, topk_val) + # num_splits must divide pages_per_seg, and num_splits*topk_val + # must fit the merge cap (8192). Clamp + round. + if splits > 1 and pages_per_seg % splits != 0: + # snap down to the largest divisor ≤ splits + for cand in range(splits, 0, -1): + if pages_per_seg % cand == 0: + splits = cand + break + while splits * topk_val > 8192 and splits > 1: + splits //= 2 + # splits=1 means "no parallel" — the parallel kernel has no + # work to do and would ask for seq_len * 5 bytes of smem (the + # Phase-1 cache), blowing past the 96 KB ceiling at seq_len + # > ~19K. Skip the parallel timing row for this config. + if splits < 2: + parallel_ms = None + parallel_splits_used = None + row = { + "mode": mode, + "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), + "power": power, + "remap_ms": None, + "topk_after_remap_ms": None, + "split_total_ms": None, + "fused_ms": fused["mean_ms"], + "parallel_ms": parallel_ms, + "parallel_splits": parallel_splits_used, + "cluster_ms": cluster_ms, + **_collect_threshold_stats( + inputs, topk_val, pages_per_seg, args, mode, power + ), + } + config["modes"].append(row) + continue + chunk_per_split = pages_per_seg // splits + parallel_x = ( + inputs["x"].view(eff_bs, pages_per_seg) + .view(eff_bs, splits, chunk_per_split) + .contiguous() + ) + parallel_out = torch.empty(eff_bs, topk_val, + dtype=torch.int32, device="cuda") parallel_args = ( - inputs["x"], - inputs["dense_kv_indptr"], - inputs["sparse_kv_indptr"], - inputs["dense_kv_indices"], - inputs["sparse_kv_indices"], - eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + parallel_x, + parallel_out, + eff_bs, splits, - mode, power, lut_t, q_t, + chunk_per_split, + topk_val, + mode, + power, ) - inputs["sparse_kv_indices"].zero_() parallel = bench_kernel( - topk_output_sglang_parallel, parallel_args, args.warmup, args.repeat + fast_fused_topk_merge, parallel_args, args.warmup, args.repeat ) parallel_ms = parallel["mean_ms"] parallel_splits_used = splits + # Hopper TBC+DSMEM variant — same args, sm_90+ only, + # cluster cap = 8. Fresh output buffer so validation can + # compare against the parallel kernel's output independently. + cluster_ms = None + if splits <= 8 and torch.cuda.get_device_capability(0)[0] >= 9: + cluster_out = torch.empty(eff_bs, topk_val, + dtype=torch.int32, device="cuda") + cluster_args = ( + parallel_x, + cluster_out, + eff_bs, + splits, + chunk_per_split, + topk_val, + mode, + power, + ) + cluster = bench_kernel( + fast_cluster_topk_merge, cluster_args, args.warmup, args.repeat + ) + cluster_ms = cluster["mean_ms"] + # Split-phase timing is only meaningful for arithmetic modes. # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside # compute_stage1_bin, which topk_remap_only cannot reproduce, so we @@ -706,6 +784,7 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "fused_ms": fused["mean_ms"], "parallel_ms": parallel_ms, "parallel_splits": parallel_splits_used, + "cluster_ms": cluster_ms, **stats, } config["modes"].append(row) @@ -725,8 +804,9 @@ def _print_remap_table(results: List[dict]) -> None: # still in the JSON for downstream tools, just not in the table. header = ( f"{'mode':<14s} {'remap_ms':>9s} {'topk_ms':>9s} {'split_ms':>9s} " - f"{'fused_ms':>9s} {'par_ms':>9s} {'splits':>6s} {'base_ms':>9s} " - f"{'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} {'s2_work':>8s}" + f"{'fused_ms':>9s} {'par_ms':>9s} {'cluster_ms':>10s} {'splits':>6s} " + f"{'base_ms':>9s} {'s1p2_load':>9s} {'eff_thr':>7s} {'rounds':>6s} " + f"{'s2_work':>8s}" ) for cfg in results: banner = ( @@ -767,6 +847,7 @@ def _fmt(v): return f"{v:9.4f}" if v is not None else f"{'N/A':>9s}" fused_str = _fmt(row.get("fused_ms")) par_str = _fmt(row.get("parallel_ms")) + cluster_str = f"{row.get('cluster_ms'):10.4f}" if row.get("cluster_ms") is not None else f"{'N/A':>10s}" splits = row.get("parallel_splits") splits_str = f"{splits:>6d}" if splits is not None else f"{'N/A':>6s}" thr_size = row.get("threshold_bin_size_mean", 0.0) @@ -781,6 +862,7 @@ def _fmt(v): f"{_fmt(row['split_total_ms'])} " f"{fused_str} " f"{par_str} " + f"{cluster_str} " f"{splits_str} " f"{base_ms:9.4f} " f"{s1p2_load:9.0f} " @@ -834,7 +916,7 @@ def _mean_or_none(vals): # across heads by construction. NUMERIC_KEYS = ( "remap_ms", "topk_after_remap_ms", "split_total_ms", "fused_ms", - "parallel_ms", + "parallel_ms", "cluster_ms", "threshold_bin_mean", "threshold_bin_max", "num_above_mean", "threshold_bin_size_mean", "threshold_bin_size_max", @@ -1052,7 +1134,7 @@ def main(): p.add_argument("--remap-bench", action="store_true", help="Run the split-phase remap/topk/fused/baseline benchmark.") p.add_argument("--bench-parallel", action="store_true", - help="Also time topk_output_sglang_parallel (multi-CTA split+merge).") + help="Also time fast_fused_topk_merge (single-kernel split+merge).") p.add_argument("--num-splits", type=int, default=-1, help="Partitions per batch for the parallel kernel. -1 = auto " "(sm_count / eff_batch_size, clamped to pages_per_seg/topk_val).") diff --git a/csrc/register.cc b/csrc/register.cc index af584d37..cf1be880 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -30,17 +30,16 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power"), py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none()); - m.def("topk_output_sglang_parallel", &topk_output_sglang_parallel, - py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), - py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), - py::arg("eff_batch_size"), py::arg("topk_val"), - py::arg("reserved_bos"), py::arg("reserved_eos"), - py::arg("max_num_pages"), - py::arg("num_splits"), - py::arg("mapping_mode"), - py::arg("mapping_power"), - py::arg("mapping_lut") = py::none(), - py::arg("mapping_quantiles") = py::none()); + m.def("fast_fused_topk_merge", &fast_fused_topk_merge, + py::arg("score"), py::arg("global_topk_indices"), + py::arg("batch_size"), py::arg("num_chunks"), + py::arg("chunk_size"), py::arg("topk_val"), + py::arg("mapping_mode"), py::arg("mapping_power")); + m.def("fast_cluster_topk_merge", &fast_cluster_topk_merge, + py::arg("score"), py::arg("global_topk_indices"), + py::arg("batch_size"), py::arg("num_chunks"), + py::arg("chunk_size"), py::arg("topk_val"), + py::arg("mapping_mode"), py::arg("mapping_power")); m.def("topk_remap_only", &topk_remap_only, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("remapped"), diff --git a/csrc/register.h b/csrc/register.h index e5a26def..ad2baaca 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -126,22 +126,33 @@ std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt ); -void topk_output_sglang_parallel( -const at::Tensor& x, -const at::Tensor& dense_kv_indptr, -const at::Tensor& sparse_kv_indptr, -const at::Tensor& dense_kv_indices, -at::Tensor& sparse_kv_indices, -const int64_t eff_batch_size, +// Two-stage parallel TopK. See csrc/topk_sglang_parallel.cu. +// score: [batch_size, num_chunks, chunk_size] bfloat16 or float32 +// global_topk_indices: [batch_size, topk_val] int32 (output) +// Caller must ensure num_chunks * topk_val <= 8192 (merge smem cap). +void fast_fused_topk_merge( +const at::Tensor& score, +at::Tensor& global_topk_indices, +const int64_t batch_size, +const int64_t num_chunks, +const int64_t chunk_size, const int64_t topk_val, -const int64_t reserved_bos, -const int64_t reserved_eos, -const int64_t max_num_pages, -const int64_t num_splits, const int64_t mapping_mode, -const double mapping_power, -std::optional mapping_lut = std::nullopt, -std::optional mapping_quantiles = std::nullopt +const double mapping_power +); + +// Hopper TBC+DSMEM fused TopK merge. See csrc/topk_sglang_cluster.cu. +// Same signature as fast_fused_topk_merge; num_chunks is the cluster +// size and is capped at 8 (portable TBC). Requires sm_90+. +void fast_cluster_topk_merge( +const at::Tensor& score, +at::Tensor& global_topk_indices, +const int64_t batch_size, +const int64_t num_chunks, +const int64_t chunk_size, +const int64_t topk_val, +const int64_t mapping_mode, +const double mapping_power ); void topk_remap_only( diff --git a/csrc/topk_mapping.cuh b/csrc/topk_mapping.cuh index c645acbf..0b6474bc 100644 --- a/csrc/topk_mapping.cuh +++ b/csrc/topk_mapping.cuh @@ -23,8 +23,8 @@ enum TopKMappingMode { MAPPING_NONE = 0, // identity (no remap) - MAPPING_LUT_CDF = 1, // bin lookup: new_bin = lut[convert_to_uint8(x)] - MAPPING_QUANTILE = 2, // binary search over 256 calibrated quantile thresholds + // MAPPING_LUT_CDF = 1, // bin lookup: new_bin = lut[convert_to_uint8(x)] + // MAPPING_QUANTILE = 2, // binary search over 256 calibrated quantile thresholds MAPPING_POWER = 3, // sign(x) * |x|^p MAPPING_LOG = 4, // sign(x) * log(|x| + 1) MAPPING_ASINH = 6, // asinh(beta * x) @@ -50,7 +50,7 @@ enum TopKMappingMode { // per exponent slot instead of 4. Zero per-element compute overhead; // the "remap" is the bucket change. Monotonic within 2 adjacent // fp32 exponent slots. - MAPPING_DENSE_MANT = 20, // identity; bucketing handled in fused kernel + // MAPPING_DENSE_MANT = 20, // identity; bucketing handled in fused kernel }; struct TopKMappingParams { @@ -152,14 +152,10 @@ __device__ __forceinline__ float apply_transform_tmpl(float x, float p) { else if constexpr (MODE == MAPPING_LINEAR_STEEP) return transform_linear_steep(x, p); else if constexpr (MODE == MAPPING_HALF_SQUARE) return transform_half_square(x, p); else if constexpr (MODE == MAPPING_HALF_CUBE) return transform_half_cube(x, p); - else if constexpr (MODE == MAPPING_DENSE_MANT) return fmaxf(x, p); else return x; // NONE / TRUNC8 } // Pure element-wise dispatcher. Returns the *float value* after the transform. -// For bin-selection modes (LUT_CDF / QUANTILE) this is identity: the mapping -// happens in compute_stage1_bin() below instead of via a float transform, so -// Stage-2 tie-breaking uses the raw score bits for those modes. __device__ __forceinline__ float apply_transform(float x, const TopKMappingParams& params) { switch (params.mode) { case MAPPING_POWER: return transform_power(x, params.power_exp); @@ -175,23 +171,15 @@ __device__ __forceinline__ float apply_transform(float x, const TopKMappingParam case MAPPING_LINEAR_STEEP: return transform_linear_steep(x, params.power_exp); case MAPPING_HALF_SQUARE: return transform_half_square(x, params.power_exp); case MAPPING_HALF_CUBE: return transform_half_cube(x, params.power_exp); - // MAPPING_DENSE_MANT clamps small/negative values to `power_exp` - // (default 0.5) so the subsequent dense bit bucket in the fused - // kernel sees a narrow 1–2 exponent window of positive values. - // Values at/below the clamp all hash to the lowest bin, which - // is always below the topk threshold in practice. - case MAPPING_DENSE_MANT: return fmaxf(x, params.power_exp); - case MAPPING_LUT_CDF: - case MAPPING_QUANTILE: case MAPPING_TRUNC8: - default: return x; // NONE / TRUNC8 / LUT_CDF / QUANTILE + default: return x; // NONE / TRUNC8 } } -// Whether the mapping mode is a direct bin-selection function (LUT_CDF / -// QUANTILE). These modes need per-block shared-memory tables. -__device__ __forceinline__ bool mapping_uses_table(int mode) { - return mode == MAPPING_LUT_CDF || mode == MAPPING_QUANTILE; +// Bin-selection table modes (LUT_CDF / QUANTILE) have been retired. +// This helper is kept for ABI compat with callers that still invoke it. +__device__ __forceinline__ bool mapping_uses_table(int /*mode*/) { + return false; } // Binary search over a sorted [256] quantile table. Returns the largest @@ -212,21 +200,14 @@ __device__ __forceinline__ uint8_t quantile_bin_lookup( // Forward decl so compute_stage1_bin can call it. Defined in the enclosing TU. __device__ __forceinline__ uint8_t convert_to_uint8(float x); -// Compute the Stage-1 bin for a raw score under any mapping mode. LUT_CDF / -// QUANTILE use the shared-memory tables loaded at the kernel entry; every -// other mode falls back to convert_to_uint8(apply_transform(x)). +// Compute the Stage-1 bin for a raw score. LUT_CDF / QUANTILE modes +// have been removed; every mode now goes through the element-wise +// apply_transform + convert_to_uint8. __device__ __forceinline__ uint8_t compute_stage1_bin( float raw, const TopKMappingParams& params, - const uint8_t* __restrict__ s_lut, - const float* __restrict__ s_quantiles) + const uint8_t* __restrict__ /*s_lut*/, + const float* __restrict__ /*s_quantiles*/) { - switch (params.mode) { - case MAPPING_LUT_CDF: - return s_lut[convert_to_uint8(raw)]; - case MAPPING_QUANTILE: - return quantile_bin_lookup(raw, s_quantiles); - default: - return convert_to_uint8(apply_transform(raw, params)); - } + return convert_to_uint8(apply_transform(raw, params)); } diff --git a/csrc/topk_sglang.cu b/csrc/topk_sglang.cu index 73366dfa..cc8c6b37 100644 --- a/csrc/topk_sglang.cu +++ b/csrc/topk_sglang.cu @@ -706,7 +706,8 @@ __device__ void fast_topk_clean_fused( // path stays in place for standard modes. LUT_CDF / QUANTILE are not // supported by this templated kernel (they were dropped from the bench // comparison earlier). - constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + // MAPPING_DENSE_MANT has been retired; always use the fp16 bucket. + constexpr bool use_dense_bucket = false; if (tx < RADIX + 1) f_histogram[tx] = 0; __syncthreads(); @@ -1280,7 +1281,6 @@ void topk_output_sglang_fused( case MAPPING_LINEAR_STEEP:VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ case MAPPING_HALF_SQUARE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ case MAPPING_HALF_CUBE: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ - case MAPPING_DENSE_MANT: VORTEX_DISPATCH_FUSED(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break; \ default: \ TORCH_CHECK(false, "topk_output_sglang_fused: unsupported mapping_mode ", mapping.mode); \ } \ diff --git a/csrc/topk_sglang_cluster.cu b/csrc/topk_sglang_cluster.cu new file mode 100644 index 00000000..453f7bf8 --- /dev/null +++ b/csrc/topk_sglang_cluster.cu @@ -0,0 +1,684 @@ +/** + * Vortex TopK — Hopper Thread Block Cluster + Distributed Shared Memory + * single-kernel fused top-K merge. + * + * Grid = Batch * N CTAs. + * Cluster dim = N (runtime, set via cudaLaunchAttributeClusterDimension). + * Each cluster = one batch. cluster.block_rank() identifies the chunk. + * + * Stage 1 (every CTA): 8-bit radix + 8-bit refinement over its chunk, + * writing the local top-K (fp32 remapped score + int32 index) into THIS + * CTA's shared memory — never through global memory. + * + * Stage 2 (CTA 0 only): after cluster.sync(), read every CTA's + * s_export_scores / s_export_indices directly via + * cg::cluster_group::map_shared_rank() — the reads compile to + * `ld.shared::cluster`. Build a merged 8-bit histogram, find the + * coarse threshold, run the standard 8-bit refinement, and emit K + * indices to global memory using warp-popc compaction. + * + * A second cluster.sync() at the end guarantees no CTA exits while + * CTA 0 is still issuing DSMEM reads into its exported SMEM. + * + * sm_90+ only (Hopper, Blackwell). The kernel body is guarded by + * __CUDA_ARCH__ >= 900 so the file compiles cleanly against the + * sm_86/sm_89 gencode targets in setup.py — the host entrypoint + * TORCH_CHECKs the runtime device compute capability. + */ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include "register.h" + +namespace { + +constexpr int kThreadsPerBlock = 1024; +constexpr int kWarpSize = 32; +constexpr int RADIX = 256; +constexpr size_t kMaxDynSmem = 96 * 1024; +constexpr int VORTEX_MAX_TOPK = 2048; +constexpr int kMaxClusterDim = 8; // portable TBC cap + +__device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} + +// Required by topk_mapping.cuh's forward decl (even though the cluster +// kernel never calls compute_stage1_bin directly). +__device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); +} + +template +__device__ __forceinline__ float vortex_to_float(T x); +template <> +__device__ __forceinline__ float vortex_to_float(float x) { return x; } +template <> +__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); +} + +#include "topk_mapping.cuh" + +// 8-step suffix cumsum: after the call s_hist[0][i] = count of items +// with bin >= i (monotone non-increasing). Same routine as +// topk_sglang_parallel.cu. +__device__ __forceinline__ void run_cumsum_256(int s_hist[2][RADIX + 128]) { + const int tx = threadIdx.x; +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const int j = 1 << i; + const int k = i & 1; + int value = s_hist[k][tx]; + if (tx < RADIX - j) value += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = value; + } + __syncthreads(); + } +} + +// Warp-level ballot+popc compaction. Exactly one atomicAdd per warp, +// issued by the first active lane. Safe from a divergent region. +__device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { + const uint32_t mask = __activemask(); + const uint32_t ballot = __ballot_sync(mask, selected); + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_count = __popc(ballot); + const int rank_in_warp = __popc(ballot & ((1u << lane) - 1u)); + + const int first_lane = __ffs(mask) - 1; + int base = 0; + if (lane == first_lane) { + base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; + } + base = __shfl_sync(mask, base, first_lane); + return selected ? (base + rank_in_warp) : -1; +} + +namespace cg = cooperative_groups; + +template +__global__ __launch_bounds__(kThreadsPerBlock) +void TopK_Cluster_Kernel( + const ScoreT* __restrict__ score, // [Batch, N, chunk_size] + int32_t* __restrict__ global_idx, // [Batch, K] + int N, + int chunk_size, + int K, + float mapping_power) +{ +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cg::cluster_group cluster = cg::this_cluster(); + const int rank = static_cast(cluster.block_rank()); + // Grid layout: dim3(Batch * N). blockIdx.x = b * N + rank. + const int b = (blockIdx.x - rank) / N; + const int tx = threadIdx.x; + + const ScoreT* chunk_in = score + (static_cast(b) * N + rank) * chunk_size; + const int32_t idx_base = rank * chunk_size; + + // Static SMEM ------------------------------------------------------------ + alignas(128) __shared__ int s_hist_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + alignas(128) __shared__ int s_export_count; + // Rank 0 only: contiguous staging buffer for the final K indices before + // the coalesced int4 write to global memory. Sized to VORTEX_MAX_TOPK so + // we don't need to carve it out of the dynamic smem layout (which must + // keep the exports at offset 0 for DSMEM visibility). + alignas(16) __shared__ int32_t s_final_indices[VORTEX_MAX_TOPK]; + auto& s_hist = s_hist_buf[0]; + + // Dynamic SMEM ------------------------------------------------------------ + // [0, K*4) s_export_scores (fp32) <- DSMEM-visible + // [K*4, K*8) s_export_indices (int32) <- DSMEM-visible + // [K*8, K*8 + overlay) Stage-1 cache on ALL ranks; reused by + // rank 0 in Stage 2 as the N*K merge pool. + // Stage 1 : s_remapped[chunk] (fp32) + s_bins[chunk] (uint8 padded) + // Stage 2 : s_merge_scores[N*K] (fp32) + s_merge_indices[N*K] (int32) + // + // Exports sit at the start of the SMEM pool so the base offset is the + // same on every cluster CTA — cg::map_shared_rank uses that offset + // modulo the cluster stride to read a remote CTA. + extern __shared__ char smem_raw[]; + float* s_export_scores = reinterpret_cast (smem_raw); + int32_t* s_export_indices = reinterpret_cast(smem_raw + K * sizeof(float)); + float* s_remapped = reinterpret_cast (smem_raw + K * (sizeof(float) + sizeof(int32_t))); + uint8_t* s_bins = reinterpret_cast(s_remapped + chunk_size); + + // Initialize counters + pad export indices to -1 (so the degenerate + // chunk_size < K case leaves recognisable empty slots). + for (int i = tx; i < K; i += blockDim.x) { + s_export_indices[i] = -1; + s_export_scores [i] = -CUDART_INF_F; + } + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + s_export_count = 0; + } + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + // ========================================================================= + // Stage 1 — local top-K for this chunk. + // ========================================================================= + if (chunk_size <= K) { + // Degenerate: emit every valid element. + for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const int slot = warp_compact_slot(true, &s_counter); + if (slot >= 0 && slot < K) { + s_export_scores [slot] = remapped; + s_export_indices[slot] = idx + idx_base; + } + } + __syncthreads(); + if (tx == 0) s_export_count = min(s_counter, K); + } else { + // Histogram pass 1 ------------------------------------------------------ + for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t b32 = convert_to_uint32(remapped); + const int bin = (b32 >> 24) & 0xFF; + s_remapped[idx] = remapped; + s_bins [idx] = static_cast(bin); + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin = s_threshold_bin; + + // Emit bin > threshold; build sub-bin histogram on the tie bin ----- + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + const int num_iters = (chunk_size + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + const bool take_above = in_range && (bin > threshold_bin); + + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + s_export_scores [slot] = s_remapped[idx]; + s_export_indices[slot] = idx + idx_base; + } else if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // Refinement cumsum → sub-threshold bin -------------------------------- + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie refinement needed + } + __syncthreads(); + const int sub_threshold_bin = s_sub_threshold_bin; + + // Emit tie-bin items --------------------------------------------------- + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + sub_bin = (b32 >> 16) & 0xFF; + } + + const bool take_sub_above = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + s_export_scores [slot] = s_remapped[idx]; + s_export_indices[slot] = idx + idx_base; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + s_export_scores [K - pos] = s_remapped[idx]; + s_export_indices[K - pos] = idx + idx_base; + } + } + } + __syncthreads(); + if (tx == 0) s_export_count = K; + } + + // ========================================================================= + // Stage 2 — CENTRALIZED RANK-0 PULL. + // + // Control flow: + // barrier #1 (all CTAs) : release all Stage-1 exports cluster-wide. + // rank != 0 : wait at barrier #2 so their exported SMEM + // stays alive while rank 0 pulls, then exit. + // rank 0 Step A : vectorised DSMEM pull of every rank's + // s_export_* into a local N*K merge pool. + // rank 0 Step B+C : single-block 8-bit radix select over the + // merge pool, staging winners into + // s_final_indices via warp-popc + LOCAL + // atomicAdd on &s_counter / &s_last_remain. + // (No DSMEM atomics anywhere.) + // rank 0 Step D : int4-coalesced global store of + // s_final_indices[K] → global_idx[b, :K]. + // barrier #2 (all CTAs) : release idle ranks. + // ========================================================================= + + // cluster.sync() is both a cross-CTA barrier AND a cluster-wide release + // fence on shared memory, so rank 0's upcoming DSMEM reads of remote + // s_export_* observe the Stage-1 writes above. + cluster.sync(); + + if (rank != 0) { + cluster.sync(); // final barrier — keeps SMEM alive during rank 0's pull + return; + } + + // ---- rank 0 only from here on ------------------------------------------ + + // Pre-fill the staging buffer with -1 so that if fewer than K valid + // candidates exist, the unused tail emits as -1 sentinels rather than + // stale static-SMEM data. + for (int i = tx; i < K; i += blockDim.x) s_final_indices[i] = -1; + + // Reset histogram + counters for Stage 2's radix select. + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + + // Merge pool: overlays the now-dead Stage-1 cache region. Layout: + // [K*8, K*8 + N*K*4) s_merge_scores (fp32) + // [K*8 + N*K*4, K*8 + N*K*8) s_merge_indices (int32) + const int total = N * K; + float* s_merge_scores = reinterpret_cast (smem_raw + K * sizeof(float) + + K * sizeof(int32_t)); + int32_t* s_merge_indices = reinterpret_cast(s_merge_scores + total); + __syncthreads(); + + // ========================================================================= + // Step A — vectorised DSMEM pull. + // + // map_shared_rank(ptr, 0) degenerates to a local load, so we can sweep + // r=0..N-1 uniformly without special-casing the self-copy. + // ========================================================================= + #pragma unroll + for (int r = 0; r < kMaxClusterDim; ++r) { + if (r >= N) break; + const float* rem_scores = cluster.map_shared_rank(s_export_scores, r); + const int32_t* rem_indices = cluster.map_shared_rank(s_export_indices, r); + float* dst_scores = s_merge_scores + r * K; + int32_t* dst_indices = s_merge_indices + r * K; + + if ((K & 3) == 0) { + const float4* src_s4 = reinterpret_cast(rem_scores); + const int4* src_i4 = reinterpret_cast(rem_indices); + float4* dst_s4 = reinterpret_cast (dst_scores); + int4* dst_i4 = reinterpret_cast (dst_indices); + const int K4 = K >> 2; + for (int i = tx; i < K4; i += blockDim.x) { + dst_s4[i] = src_s4[i]; + dst_i4[i] = src_i4[i]; + } + } else { + for (int i = tx; i < K; i += blockDim.x) { + dst_scores [i] = rem_scores [i]; + dst_indices[i] = rem_indices[i]; + } + } + } + __syncthreads(); + + // ========================================================================= + // Step B+C — local 8-bit radix select over the N*K merge pool, with + // warp-popc compaction into s_final_indices. Ported from + // topk_sglang_parallel.cu Phase 2. + // ========================================================================= + + // (1) Coarse 8-bit histogram on bits [31:24] of the sign-flipped score. + const int num_iters_m = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_merge_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_merge_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); + } + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + // Fast path: fewer valid candidates than K — emit them all, skip refinement. + const int valid_count = s_hist[0]; + if (valid_count <= K) { + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_merge_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take && slot < K) s_final_indices[slot] = s_merge_indices[i]; + } + } else { + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; + + // (2) Emit above-threshold winners; build sub-bin histogram on tie-bin. + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t idx = s_merge_indices[i]; + if (idx >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_merge_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + s_final_indices[slot] = s_merge_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // (3) Refinement cumsum → sub-threshold bin. + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie-bin refinement needed + } + __syncthreads(); + const int sub_threshold_bin_m = s_sub_threshold_bin; + + // (4) Emit tie-bin items: hard wins via warp-popc, remainder via local + // atomic budget. Both atomics hit rank-0's native SMEM only. + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_threshold = false; + int sub_bin = -1; + if (i < total) { + const int32_t idx = s_merge_indices[i]; + if (idx >= 0) { + const uint32_t b32 = convert_to_uint32(s_merge_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_threshold = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take_sub_above = in_threshold && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + s_final_indices[slot] = s_merge_indices[i]; + } else if (in_threshold && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) s_final_indices[K - pos] = s_merge_indices[i]; + } + } + } + + __syncthreads(); + + // ========================================================================= + // Step D — coalesced int4 store of s_final_indices[K] → global_idx[b, :K]. + // ========================================================================= + int32_t* out_idx = global_idx + static_cast(b) * K; + if ((K & 3) == 0) { + const int4* src = reinterpret_cast(s_final_indices); + int4* dst = reinterpret_cast (out_idx); + const int K4 = K >> 2; + for (int i = tx; i < K4; i += blockDim.x) dst[i] = src[i]; + } else { + for (int i = tx; i < K; i += blockDim.x) out_idx[i] = s_final_indices[i]; + } + + // Final barrier: releases ranks 1..N-1 that were holding their SMEM + // alive while rank 0 was pulling in Step A. + cluster.sync(); +#else + // sm_86/sm_89 fallback: host dispatcher TORCH_CHECKs compute + // capability, so this stub is never actually invoked. The empty + // body still needs to reference the params so nvcc doesn't warn. + (void)score; (void)global_idx; + (void)N; (void)chunk_size; (void)K; (void)mapping_power; +#endif +} + +// One-shot cudaFuncSetAttribute for dynamic smem ceiling. +template +void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + }(); + TORCH_CHECK(result == cudaSuccess, + "fast_cluster_topk_merge setup failed: ", + ::cudaGetErrorString(result)); +} + +} // namespace + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +// ============================================================================ +// Host entry point — fast_cluster_topk_merge. +// +// score [batch_size, num_chunks, chunk_size] bf16 or f32 +// global_topk_indices [batch_size, topk_val] int32 (out) +// +// No workspace tensors — Stage-1 partial top-K lives in shared memory, +// consumed by CTA 0 of the cluster via DSMEM. +// ============================================================================ +void fast_cluster_topk_merge( + const at::Tensor& score, + at::Tensor& global_topk_indices, + const int64_t batch_size, + const int64_t num_chunks, + const int64_t chunk_size, + const int64_t topk_val, + const int64_t mapping_mode, + const double mapping_power) +{ + CHECK_CUDA(score); + CHECK_CUDA(global_topk_indices); + + TORCH_CHECK(topk_val > 0 && topk_val <= VORTEX_MAX_TOPK, + "fast_cluster_topk_merge: topk_val=", topk_val, + " must be in (0, ", VORTEX_MAX_TOPK, "]"); + TORCH_CHECK(num_chunks >= 1 && num_chunks <= kMaxClusterDim, + "fast_cluster_topk_merge: num_chunks=", num_chunks, + " must be in [1, ", kMaxClusterDim, "] (portable TBC cap)"); + TORCH_CHECK(batch_size >= 1, "batch_size must be >= 1"); + TORCH_CHECK(chunk_size >= 1, "chunk_size must be >= 1"); + TORCH_CHECK(global_topk_indices.scalar_type() == at::kInt, + "global_topk_indices must be int32"); + TORCH_CHECK(global_topk_indices.numel() >= batch_size * topk_val, + "global_topk_indices is too small for batch_size * topk_val"); + + TORCH_CHECK( + mapping_mode == MAPPING_NONE || + mapping_mode == MAPPING_POWER || + mapping_mode == MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP, + "fast_cluster_topk_merge: mapping_mode=", mapping_mode, + " not supported. Valid: NONE(0), POWER(3), ASINH(6), LOG1P(7), " + "ERF(9), TANH(10), SUBTRACT(11), EXP_STRETCH(13), SHIFT_POW2(15), " + "SHIFT_POW3(16), LINEAR_STEEP(17)."); + + // Hardware capability gate — Thread Block Clusters require sm_90+. + int dev; + TORCH_CHECK(::cudaGetDevice(&dev) == cudaSuccess, "cudaGetDevice failed"); + cudaDeviceProp prop{}; + TORCH_CHECK(::cudaGetDeviceProperties(&prop, dev) == cudaSuccess, + "cudaGetDeviceProperties failed"); + TORCH_CHECK(prop.major >= 9, + "fast_cluster_topk_merge requires sm_90+ (Hopper/Blackwell). " + "Detected compute capability ", prop.major, ".", prop.minor, "."); + + // Dynamic smem layout (per CTA): + // exports : topk_val * (float + int32) = topk_val * 8 B (DSMEM-visible) + // overlay : used by Stage 1 as the remap/bin cache (all ranks), reused + // by Stage 2 on rank 0 as the N*K merge pool. Sized to the + // larger of the two so either fits. + // cache_bytes = chunk_size * (float + uint8), uint8 region padded. + // merge_bytes = num_chunks * topk_val * (float + int32). + const size_t export_bytes = static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t cache_bytes = static_cast(chunk_size) * sizeof(float) + + ((static_cast(chunk_size) + 15) & ~size_t(15)); + const size_t merge_bytes = static_cast(num_chunks) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t overlay_bytes = (cache_bytes > merge_bytes) ? cache_bytes : merge_bytes; + const size_t smem_bytes = export_bytes + overlay_bytes; + TORCH_CHECK(smem_bytes <= kMaxDynSmem, + "fast_cluster_topk_merge: smem ", smem_bytes, + " > ceiling ", kMaxDynSmem, + " (topk_val=", topk_val, ", num_chunks=", num_chunks, + ", chunk_size=", chunk_size, ")"); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + const dim3 grid(static_cast(batch_size * num_chunks), 1, 1); + const dim3 block(kThreadsPerBlock, 1, 1); + + cudaLaunchAttribute attrs[1]{}; + attrs[0].id = cudaLaunchAttributeClusterDimension; + attrs[0].val.clusterDim.x = static_cast(num_chunks); + attrs[0].val.clusterDim.y = 1; + attrs[0].val.clusterDim.z = 1; + + cudaLaunchConfig_t cfg{}; + cfg.gridDim = grid; + cfg.blockDim = block; + cfg.dynamicSmemBytes = smem_bytes; + cfg.stream = stream; + cfg.attrs = attrs; + cfg.numAttrs = 1; + + #define LAUNCH(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, \ + kMaxDynSmem>(); \ + const auto rc_launch = ::cudaLaunchKernelEx( \ + &cfg, TopK_Cluster_Kernel, \ + PTR_EXPR, \ + global_topk_indices.data_ptr(), \ + static_cast(num_chunks), \ + static_cast(chunk_size), \ + static_cast(topk_val), \ + mp); \ + TORCH_CHECK(rc_launch == cudaSuccess, \ + "fast_cluster_topk_merge launch failed: ", \ + ::cudaGetErrorString(rc_launch)); \ + } while (0) + + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping_mode) { \ + case MAPPING_NONE: LAUNCH(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: LAUNCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_ASINH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + default: TORCH_CHECK(false, "unreachable mode"); \ + } \ + } while (0) + + if (score.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(score.data_ptr())); + } else if (score.scalar_type() == at::ScalarType::Float) { + DISPATCH_MODE(float, score.data_ptr()); + } else { + TORCH_CHECK(false, "fast_cluster_topk_merge: unsupported dtype ", + score.scalar_type()); + } + + #undef DISPATCH_MODE + #undef LAUNCH + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "fast_cluster_topk_merge kernel failed: ", ::cudaGetErrorString(rc)); +} diff --git a/csrc/topk_sglang_parallel.cu b/csrc/topk_sglang_parallel.cu index 72193917..e728940d 100644 --- a/csrc/topk_sglang_parallel.cu +++ b/csrc/topk_sglang_parallel.cu @@ -1,32 +1,37 @@ /** - * Vortex TopK parallel kernel (single-kernel, last-CTA-wins merge). + * Vortex TopK — single-kernel parallel+merge pipeline. * - * Motivation: the single-CTA fused kernel in topk_sglang.cu pins each - * batch segment to one CTA, which underutilises the GPU for small - * effective batch sizes (e.g. bs=4 on H100 leaves ~97% of SMs idle). + * ONE kernel launch. Per-chunk selection and cross-chunk merge both run + * inside the same grid-(N, Batch) launch. The last-arriving CTA for + * each batch (detected by a program-lifetime __device__ done-counter + + * atomicInc wrap-around) carries out the merge — no second launch, no + * per-call cudaMemset for barrier state. * - * This kernel launches `num_splits * eff_batch_size` CTAs in a single - * launch. CTAs sharing the same `bx` (batch index) partition that - * batch's score range `num_splits` ways and each compute a per-partition - * top-K via the same two-stage radix the fused kernel uses. Partial - * results are written into a per-batch workspace. + * Correctness: + * Stage 1 per-chunk uses ONE 8-bit radix histogram + ONE 8-bit + * refinement round on the threshold bin (16 bits of selection + * precision). For bf16 input (8 mantissa bits effective), this is + * lossless — two items with the same 16-bit key are bit-identical as + * bf16 values. * - * Merge is done WITHOUT a second kernel launch. Each CTA, after - * finishing its partition's top-K, does `atomicAdd(&done_counter[bx], - * 1)`. The CTA whose atomicAdd returns `num_splits - 1` is the last - * one to arrive for batch bx, and it alone carries out the merge: - * reads the `num_splits * topk_val` candidates from the workspace, - * runs a small two-stage radix on the already-remapped keys, writes - * final top-K page IDs to sparse_kv_indices. + * Stage 2 merge operates on N*K pre-remapped keys in shared memory + * and uses the same 8-bit-hist + 8-bit-refine pattern, which is + * strictly sufficient to pick the correct top-K from the union. * - * Correctness: per-partition top-K is a conservative upper bound on - * the global top-K (worst case: all top-K items land in one - * partition). Every global top-K item is therefore guaranteed to be - * in some partition's top-K, and the merge picks the final top-K - * from the union — sorted-scores match the fused kernel exactly. - * Tie-breaking can differ because radix tie-breaks depend on atomic - * race order. + * Low-overhead primitives: + * - Warp-level ballot+popc compaction on the "bin > threshold" path + * so each warp issues ONE atomicAdd on the block counter instead + * of one per thread. + * - Program-lifetime __device__ done-counter sized for realistic + * batch×head counts; atomicInc wraps back to 0 at num_chunks so + * there's no memset on the hot path. + * - Vectorised float4/int4 loads from global → smem in the merge. + * + * Supported mapping modes (IDs from csrc/topk_mapping.cuh): + * 3=POWER, 6=ASINH, 7=LOG1P, 9=ERF, 10=TANH, 11=SUBTRACT, + * 13=EXP_STRETCH, 15=SHIFT_POW2, 16=SHIFT_POW3, 17=LINEAR_STEEP. */ + #include #include #include @@ -46,35 +51,35 @@ namespace { -// ---- Launch constants (match topk_sglang.cu) -------------------------------- +// ---- Launch constants ------------------------------------------------------ -constexpr int kThreadsPerBlock = 1024; +constexpr int kThreadsPerBlock = 1024; +constexpr int kWarpSize = 32; +constexpr int RADIX = 256; +constexpr size_t kMaxDynSmem = 96 * 1024; +constexpr int VORTEX_MAX_TOPK = 2048; -#ifdef USE_ROCM -#ifdef SGL_TOPK_DYNAMIC_SMEM_BYTES -constexpr size_t kSmem = static_cast(SGL_TOPK_DYNAMIC_SMEM_BYTES); -#else -constexpr size_t kSmem = 48 * 1024; -#endif -#else -constexpr size_t kSmem = 8 * 1024 * sizeof(uint32_t); // 32 KB -#endif +// Stage-2 holds N*K (key, idx) pairs in smem = 8 B/item. +constexpr int kMergeCap = 8192; -constexpr size_t kFusedSmemMax = 96 * 1024; // combined kernel dynamic smem ceiling -constexpr int VORTEX_MAX_TOPK = 2048; +// Max batch the single kernel can sequence. Sized for realistic +// bs×heads (decode). __device__ globals are zero-initialised at +// program start; atomicInc wrap-around keeps each entry at 0 between +// launches, so no host-side memset on the hot path. +constexpr int kMaxBatch = 8192; +__device__ unsigned int g_done_counter[kMaxBatch]; -// ---- Program-lifetime done-counter array ---------------------------------- -// Used by the last-CTA-wins barrier. __device__ linkage → zero-initialised at -// program startup. atomicInc(ptr, num_splits-1) cycles each entry back to 0 -// after every launch, so we never pay a cudaMemset on entry to the host fn. -// Sized for the largest realistic effective batch we'd ever run through the -// parallel kernel (decode bs×heads). Host validates the cap. -constexpr int kMaxParallelEffBs = 8192; -__device__ int g_parallel_done_counter[kMaxParallelEffBs]; +// ---- Device helpers -------------------------------------------------------- -// ---- Device helpers (duplicated from topk_sglang.cu) ----------------------- +__device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); +} -__device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { +// Required symbol for topk_mapping.cuh's compute_stage1_bin. Not used +// directly by the kernel body here, but the header includes a forward +// declaration that resolves against this definition at link time. +__device__ __forceinline__ uint8_t convert_to_uint8(float x) { __half h = __float2half_rn(x); uint16_t bits = __half_as_ushort(h); uint16_t key = (bits & 0x8000) ? static_cast(~bits) @@ -82,17 +87,6 @@ __device__ __forceinline__ auto convert_to_uint8(float x) -> uint8_t { return static_cast(key >> 8); } -__device__ __forceinline__ auto convert_to_uint32(float x) -> uint32_t { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); -} - -__device__ __forceinline__ auto convert_to_uint8_dense(float x) -> uint8_t { - const uint32_t bits = __float_as_uint(x); - const uint32_t key = (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); - return static_cast((key >> 16) & 0xFFu); -} - template __device__ __forceinline__ float vortex_to_float(T x); template <> @@ -105,554 +99,399 @@ __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) #include "topk_mapping.cuh" // ============================================================================ -// fast_topk_partition +// 8-step suffix cumsum over 256 bins. After the call s_hist[0][i] is +// the count of items with bin >= i (monotone non-increasing). +// ============================================================================ +__device__ __forceinline__ void run_cumsum_256(int s_hist[2][RADIX + 128]) { + const int tx = threadIdx.x; +#pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const int j = 1 << i; + const int k = i & 1; + int value = s_hist[k][tx]; + if (tx < RADIX - j) value += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = value; + } + __syncthreads(); + } +} + +// ============================================================================ +// Warp-level ballot+popc compaction. +// +// Every participating thread offers a boolean `selected`. Exactly ONE +// atomicAdd per warp — issued by the first active lane — reserves +// `warp_count` slots; other selected lanes derive their slot via a +// popc prefix sum. Safe when called from inside a divergent region +// (uses __activemask(), not a fixed all-ones mask). +// ============================================================================ +__device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { + const uint32_t mask = __activemask(); + const uint32_t ballot = __ballot_sync(mask, selected); + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_count = __popc(ballot); + const int rank_in_warp = __popc(ballot & ((1u << lane) - 1u)); + + const int first_lane = __ffs(mask) - 1; + int base = 0; + if (lane == first_lane) { + base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; + } + base = __shfl_sync(mask, base, first_lane); + return selected ? (base + rank_in_warp) : -1; +} + +// ============================================================================ +// Combined kernel — Stage 1 (per-chunk) + barrier + Stage 2 (merge). // -// Per-partition two-stage radix. Same algorithm as the fused kernel's -// fast_topk_clean_fused in topk_sglang.cu, with identical mapping-mode -// dispatch and bucket selection. Returns slice-local indices of the -// top `target_k` elements in `index`. +// Grid = (Batch, N). One CTA per (batch, chunk). +// Block = kThreadsPerBlock = 1024. // -// Reuses the caller-provided extern shared memory region `f_input_idx` -// (2 × SMEM_INPUT_SIZE ints) and the `s_bins` byte cache immediately -// after it. The caller also supplies the static histogram / counter -// storage through the template's body — each device-function-private -// __shared__ declaration gets its own offset, but total static smem -// stays small enough to fit comfortably alongside the dynamic region. +// Shared-memory layout (reused across phases): +// Phase 1 needs: +// s_remapped[chunk_size] (float) — cached apply_transform output. +// s_bins[chunk_size] (uint8) — cached coarse bin. +// Merge needs: +// s_scores[N*K] (float) — pair buffer, loaded vectorised. +// s_indices[N*K] (int32) — pair buffer. +// kSmemBytes is sized to host max of both. +// +// Sync between phases: +// After Phase 1's workspace writes, __threadfence() publishes them, +// then thread 0 does `atomicInc(&g_done_counter[bx], N-1)` which +// cycles 0→1→…→N-1→0 so no reset is needed between calls. The CTA +// whose returned `old == N-1` is the last one — it falls through +// into the merge; other CTAs return. // ============================================================================ template -__device__ void fast_topk_partition( - const ScoreT* __restrict__ input, - int* __restrict__ index, - int* __restrict__ f_input_idx_raw, // 2 × SMEM_INPUT_SIZE ints - uint8_t* __restrict__ s_bins, // `length` bytes - int row_start, - int length, - int target_k, - const TopKMappingParams mapping) +__global__ __launch_bounds__(kThreadsPerBlock) +void TopK_Parallel_Kernel( + const ScoreT* __restrict__ score, // [Batch, N, chunk_size] + int32_t* __restrict__ global_idx, // [Batch, K] + float* __restrict__ partial_keys, // [Batch, N, K] workspace + int32_t* __restrict__ partial_idx, // [Batch, N, K] workspace + int N, + int chunk_size, + int K, + float mapping_power) { - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - - alignas(128) __shared__ int f_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int f_counter; - alignas(128) __shared__ int f_threshold_bin_id; - alignas(128) __shared__ int f_num_input[2]; - - auto& f_histogram = f_histogram_buf[0]; - - // Treat the caller's extern-smem region as two banks of SMEM_INPUT_SIZE ints. - auto f_input_idx = [&](int bank, int pos) -> int& { - return f_input_idx_raw[bank * SMEM_INPUT_SIZE + pos]; - }; - + const int b = blockIdx.x; + const int n = blockIdx.y; const int tx = threadIdx.x; - constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); + // Addresses for this CTA's chunk slice and its slot in the workspace. + const ScoreT* chunk_in = score + (static_cast(b) * N + n) * chunk_size; + float* chunk_keys_out = partial_keys + (static_cast(b) * N + n) * K; + int32_t* chunk_idx_out = partial_idx + (static_cast(b) * N + n) * K; + const int32_t idx_base = n * chunk_size; // batch-local offset - if (tx < RADIX + 1) f_histogram[tx] = 0; - __syncthreads(); + // ---------------------------------------------------------------- smem + extern __shared__ char smem_raw[]; - // Stage 1 pass 1: bin every element and cache the bin in s_bins so - // pass 2 doesn't re-load scores or re-apply the mapping. - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform_tmpl(raw, mapping.power_exp); - int bin; - if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(remapped)); - else bin = static_cast(convert_to_uint8(remapped)); - s_bins[idx] = static_cast(bin); - ::atomicAdd(&f_histogram[bin], 1); - } - __syncthreads(); - - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = f_histogram_buf[k][tx]; - if (tx < RADIX - j) value += f_histogram_buf[k][tx + j]; - f_histogram_buf[k ^ 1][tx] = value; + // Shared-memory counters / histogram live in static smem so the + // Phase-1 and merge phases can share the same dynamic pool. + alignas(128) __shared__ int s_hist_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + alignas(128) __shared__ int s_is_last; + auto& s_hist = s_hist_buf[0]; + + // ========================================================================= + // Phase 1: per-chunk TopK via 8-bit radix + 8-bit refinement. + // ========================================================================= + // + // Dynamic smem region used as: + // s_remapped : chunk_size * 4 B (cached apply_transform output) + // s_bins : chunk_size * 1 B (cached Stage-1 bin) + // + // Refinement is a second 8-bit bucket on bits [23:16] of the + // sign-flipped u32 key, used to refine the threshold bin. 8 + 8 = + // 16 bits of selection precision → lossless for bf16. + float* s_remapped = reinterpret_cast(smem_raw); + uint8_t* s_bins = reinterpret_cast(s_remapped + chunk_size); + + // ---- Degenerate chunk_size <= K : emit everything as-is. ------------- + if (chunk_size <= K) { + for (int i = tx; i < K; i += blockDim.x) { + if (i < chunk_size) { + const float raw = vortex_to_float(chunk_in[i]); + chunk_keys_out[i] = apply_transform_tmpl(raw, mapping_power); + chunk_idx_out [i] = i + idx_base; + } else { + chunk_keys_out[i] = -CUDART_INF_F; + chunk_idx_out [i] = -1; } - __syncthreads(); } - }; + } else { + // ---- Histogram pass 1: transform + bucket; cache both to smem. ---- + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { s_counter = 0; s_threshold_bin = -1; s_last_remain = 0; } + __syncthreads(); - run_cumsum(); - if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { - f_threshold_bin_id = tx; - f_num_input[0] = 0; - f_counter = 0; - } - __syncthreads(); + for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t b32 = convert_to_uint32(remapped); + const int bin = (b32 >> 24) & 0xFF; + s_remapped[idx] = remapped; + s_bins [idx] = static_cast(bin); + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); - const auto threshold_bin = f_threshold_bin_id; - topk -= f_histogram[threshold_bin + 1]; + run_cumsum_256(s_hist_buf); - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const int bin = static_cast(s_bins[idx]); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&f_counter, 1); - index[pos] = idx; - } + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; } __syncthreads(); - return; - } else { - __syncthreads(); - if (tx < RADIX + 1) f_histogram[tx] = 0; + const int threshold_bin = s_threshold_bin; + + // ---- Emit bin > threshold (warp-popc) and build refinement hist. ---- + if (tx < RADIX + 1) s_hist[tx] = 0; __syncthreads(); - constexpr int sub_bin_offset_start = use_dense_bucket ? 8 : 24; - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - const int bin = static_cast(s_bins[idx]); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&f_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform_tmpl(raw, mapping.power_exp); - const auto pos = ::atomicAdd(&f_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - f_input_idx(0, pos) = idx; - const auto b32 = convert_to_uint32(remapped); - const auto sub_bin = (b32 >> sub_bin_offset_start) & 0xFF; - ::atomicAdd(&f_histogram[sub_bin], 1); - } + const int num_iters = (chunk_size + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + const bool take_above = in_range && (bin > threshold_bin); + + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + chunk_keys_out[slot] = s_remapped[idx]; + chunk_idx_out [slot] = idx + idx_base; + } else if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); } } __syncthreads(); - } - constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; - constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - if (round >= stage2_max_rounds) break; - __shared__ int f_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = f_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input - : int(SMEM_INPUT_SIZE); - run_cumsum(); - if (tx < RADIX && f_histogram[tx] > topk && f_histogram[tx + 1] <= topk) { - f_threshold_bin_id = tx; - f_num_input[r_idx ^ 1] = 0; - f_last_remain = topk - f_histogram[tx + 1]; + // ---- Refinement cumsum → sub-threshold bin. ------------------------ + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + // budget for items at the sub-threshold bin + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + // Only possible if last_remain == 0 (bin > threshold already emitted + // exactly K items). Nothing more to do; make the sub bin a sentinel. + s_sub_threshold_bin = RADIX; // no sub-threshold bin } __syncthreads(); - - const auto threshold_bin = f_threshold_bin_id; - topk -= f_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = f_input_idx(r_idx, i); - const auto offset = stage2_offset_start - round * 8; - const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform_tmpl(raw, mapping.power_exp); - const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&f_counter, 1); - index[pos] = idx; - } + const int sub_threshold_bin = s_sub_threshold_bin; + + // ---- Emit threshold-bin items using sub-threshold logic. ---------- + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + sub_bin = (b32 >> 16) & 0xFF; } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) f_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = f_input_idx(r_idx, i); - const float raw = vortex_to_float(input[idx + row_start]); - const float remapped = apply_transform_tmpl(raw, mapping.power_exp); - const auto offset = stage2_offset_start - round * 8; - const auto bin = (convert_to_uint32(remapped) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&f_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == stage2_max_rounds - 1) { - const auto pos = ::atomicAdd(&f_last_remain, -1); - if (pos > 0) index[target_k - pos] = idx; - } else { - const auto pos = ::atomicAdd(&f_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - f_input_idx(r_idx ^ 1, pos) = idx; - const auto b32 = convert_to_uint32(remapped); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&f_histogram[sub_bin], 1); - } - } + + const bool take_sub_above = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + chunk_keys_out[slot] = s_remapped[idx]; + chunk_idx_out [slot] = idx + idx_base; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + chunk_keys_out[K - pos] = s_remapped[idx]; + chunk_idx_out [K - pos] = idx + idx_base; } } - __syncthreads(); } + __syncthreads(); } -} - -// ============================================================================ -// fast_topk_merge -// -// Run by the last-arriving CTA of each batch. Input is the combined -// candidate list (`num_splits * topk_val` float keys + int indices, -// with idx==-1 marking sentinel slots). Reuses the same extern-smem -// region `s_input_idx_raw` that Phase 1 used — its earlier contents -// are dead at this point. Output: top-`target_k` positions into -// `index`, indexing the combined candidate list. -// -// Bucketing matches the fused kernel's bucketing for the given MODE -// so the merged top-K is lossless modulo atomic tie-break order. -// ============================================================================ -template -__device__ void fast_topk_merge( - const float* __restrict__ input, - const int* __restrict__ valid_mask, - int* __restrict__ index, - int* __restrict__ s_input_idx_raw, // 2 × SMEM_INPUT_SIZE ints - int row_start, - int length, - int target_k) -{ - int topk = target_k; - constexpr auto BLOCK_SIZE = 1024; - constexpr auto RADIX = 256; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - constexpr bool use_dense_bucket = (MODE == MAPPING_DENSE_MANT); - constexpr int stage2_offset_start = use_dense_bucket ? 8 : 24; - constexpr int stage2_max_rounds = use_dense_bucket ? 2 : 4; - - alignas(128) __shared__ int s_histogram_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin_id; - alignas(128) __shared__ int s_num_input[2]; - - auto& s_histogram = s_histogram_buf[0]; - auto s_input_idx = [&](int bank, int pos) -> int& { - return s_input_idx_raw[bank * SMEM_INPUT_SIZE + pos]; - }; - - const int tx = threadIdx.x; - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - if (valid_mask[idx + row_start] < 0) continue; // sentinel; skip - const float v = input[idx + row_start]; - int bin; - if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(v)); - else bin = static_cast(convert_to_uint8(v)); - ::atomicAdd(&s_histogram[bin], 1); - } + // ========================================================================= + // Barrier: publish this CTA's workspace writes and atomicInc the + // per-batch done-counter. The CTA that sees old == N-1 is the last + // arriving one; every other CTA returns here. + // ========================================================================= + __threadfence(); __syncthreads(); - - const auto run_cumsum = [&] { -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const auto j = 1 << i; - const auto k = i & 1; - auto value = s_histogram_buf[k][tx]; - if (tx < RADIX - j) value += s_histogram_buf[k][tx + j]; - s_histogram_buf[k ^ 1][tx] = value; - } - __syncthreads(); - } - }; - - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[0] = 0; - s_counter = 0; + if (tx == 0) { + const unsigned int old = ::atomicInc( + &g_done_counter[b], static_cast(N - 1)); + s_is_last = (old == static_cast(N - 1)) ? 1 : 0; } __syncthreads(); + if (s_is_last == 0) return; - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - if (valid_mask[idx + row_start] < 0) continue; - const float v = input[idx + row_start]; - int bin; - if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(v)); - else bin = static_cast(convert_to_uint8(v)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } + // ========================================================================= + // Phase 2 (merge, only in last-arriving CTA): + // load N*K candidates into smem (vectorised) → + // 8-bit histogram in smem → + // threshold → warp-popc emit above + tie-bin refinement. + // ========================================================================= + const int total = N * K; + const float* keys_in = partial_keys + static_cast(b) * total; + const int32_t* idx_in = partial_idx + static_cast(b) * total; + int32_t* out_idx = global_idx + static_cast(b) * K; + + // Reuse the same dynamic smem region as Phase 1 — Phase 1's caches + // are dead now. Layout: [ s_scores : total floats | s_indices : total int32 ]. + float* s_scores = reinterpret_cast(smem_raw); + int32_t* s_indices = reinterpret_cast(s_scores + total); + + // Vectorised 128-bit loads when `total` is a multiple of 4. + if ((total & 3) == 0) { + const float4* keys_v = reinterpret_cast(keys_in); + const int4* idx_v = reinterpret_cast (idx_in); + float4* ss_v = reinterpret_cast (s_scores); + int4* si_v = reinterpret_cast (s_indices); + const int total4 = total >> 2; + for (int i = tx; i < total4; i += blockDim.x) { + ss_v[i] = keys_v[i]; + si_v[i] = idx_v [i]; } - __syncthreads(); - return; } else { - __syncthreads(); - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - - for (int idx = tx; idx < length; idx += BLOCK_SIZE) { - if (valid_mask[idx + row_start] < 0) continue; - const auto raw_input = input[idx + row_start]; - int bin; - if constexpr (use_dense_bucket) bin = static_cast(convert_to_uint8_dense(raw_input)); - else bin = static_cast(convert_to_uint8(raw_input)); - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - const auto pos = ::atomicAdd(&s_num_input[0], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx(0, pos) = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> stage2_offset_start) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } + for (int i = tx; i < total; i += blockDim.x) { + s_scores [i] = keys_in[i]; + s_indices[i] = idx_in [i]; } - __syncthreads(); } -#pragma unroll 4 - for (int round = 0; round < 4; ++round) { - if (round >= stage2_max_rounds) break; - __shared__ int s_last_remain; - const auto r_idx = round % 2; - - const auto _raw_num_input = s_num_input[r_idx]; - const auto num_input = (_raw_num_input < int(SMEM_INPUT_SIZE)) ? _raw_num_input - : int(SMEM_INPUT_SIZE); - run_cumsum(); - if (tx < RADIX && s_histogram[tx] > topk && s_histogram[tx + 1] <= topk) { - s_threshold_bin_id = tx; - s_num_input[r_idx ^ 1] = 0; - s_last_remain = topk - s_histogram[tx + 1]; - } - __syncthreads(); + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + __syncthreads(); - const auto threshold_bin = s_threshold_bin_id; - topk -= s_histogram[threshold_bin + 1]; - - if (topk == 0) { - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx(r_idx, i); - const auto offset = stage2_offset_start - round * 8; - const auto bin = (convert_to_uint32(input[idx + row_start]) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } - } - __syncthreads(); - break; - } else { - __syncthreads(); - if (tx < RADIX + 1) s_histogram[tx] = 0; - __syncthreads(); - for (int i = tx; i < num_input; i += BLOCK_SIZE) { - const auto idx = s_input_idx(r_idx, i); - const auto raw_input = input[idx + row_start]; - const auto offset = stage2_offset_start - round * 8; - const auto bin = (convert_to_uint32(raw_input) >> offset) & 0xFF; - if (bin > threshold_bin) { - const auto pos = ::atomicAdd(&s_counter, 1); - index[pos] = idx; - } else if (bin == threshold_bin) { - if (round == stage2_max_rounds - 1) { - const auto pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) index[target_k - pos] = idx; - } else { - const auto pos = ::atomicAdd(&s_num_input[r_idx ^ 1], 1); - if (C10_LIKELY(pos < SMEM_INPUT_SIZE)) { - s_input_idx(r_idx ^ 1, pos) = idx; - const auto b32 = convert_to_uint32(raw_input); - const auto sub_bin = (b32 >> (offset - 8)) & 0xFF; - ::atomicAdd(&s_histogram[sub_bin], 1); - } - } - } - } - __syncthreads(); + // (2) 8-bit histogram in smem. + const int num_iters_m = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); } } -} + __syncthreads(); -// ============================================================================ -// Combined kernel. -// -// Grid: (num_splits, eff_batch_size). Every CTA: -// 1. Computes its partition's top-K (fast_topk_partition). -// 2. Writes (remapped key, batch-local idx) pairs + sentinels to the -// per-batch workspace slot. -// 3. __threadfence() to publish the writes, then atomicAdd on the -// per-batch done-counter. The CTA whose atomicAdd returns -// num_splits - 1 is the last one for this batch. -// 4. If last: run the merge (fast_topk_merge) on the combined -// num_splits*topk_val candidates and write final page IDs to -// sparse_kv_indices. Other CTAs exit. -// ============================================================================ -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopKOutput_Parallel_Kernel( - const ScoreT* __restrict__ score, - const int* __restrict__ dense_kv_indptr, - const int* __restrict__ sparse_kv_indptr, - const int* __restrict__ dense_kv_indices, - int* __restrict__ sparse_kv_indices, - float* __restrict__ partial_keys, // [eff_bs * num_splits * topk_val] - int* __restrict__ partial_idx, // [eff_bs * num_splits * topk_val] - const int topk_val, - const int num_splits, - const int page_reserved_bos, - const int page_reserved_eos, - const int chunk_bytes, // smem bytes reserved for s_bins - const TopKMappingParams mapping) -{ - // ---- Dynamic smem layout ------------------------------------------------- - // [ f_input_idx (2 × SMEM_INPUT_SIZE ints = kSmem bytes) - // s_bins (chunk_bytes, only valid during Phase 1) ] - // The merge doesn't touch s_bins, so its extern region overlaps - // f_input_idx harmlessly. - extern __shared__ int smem_scratch[]; - constexpr auto SMEM_INPUT_SIZE = kSmem / (2 * sizeof(int)); - int* f_input_idx_raw = smem_scratch; - uint8_t* s_bins = reinterpret_cast(&smem_scratch[2 * SMEM_INPUT_SIZE]); - (void)chunk_bytes; // sizing is the host's responsibility; kernel just uses it - - // s_indices doubles as the partition's radix output AND the merge's radix - // output — they run sequentially on the same CTA, so the same ~2K slots - // are reused. Stores up to VORTEX_MAX_TOPK = 2048 entries. - __shared__ int s_indices[VORTEX_MAX_TOPK]; - // Broadcasts whether this CTA is the last-arriving one for its batch. - __shared__ int s_is_last; - - const int p = blockIdx.x; - const int bx = blockIdx.y; - const int tx = threadIdx.x; + run_cumsum_256(s_hist_buf); - const int start = dense_kv_indptr[bx] + page_reserved_bos; - const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; - const int total_len = end - start; - - // Short batch: fused kernel returns without writing; match that. - if (total_len <= topk_val) return; - - const size_t slot_base = (static_cast(bx) * num_splits + p) * topk_val; - float* keys_out = partial_keys + slot_base; - int* idx_out = partial_idx + slot_base; - - const int chunk = (total_len + num_splits - 1) / num_splits; - const int part_start = p * chunk; - const int raw_part_end = part_start + chunk; - const int part_end = raw_part_end < total_len ? raw_part_end : total_len; - const int part_len = (part_end > part_start) ? (part_end - part_start) : 0; - - // Sentinel tail: merge filters these by idx == -1. Only fill the range - // that won't be overwritten with real data. - const int real_fill = (part_len < topk_val) ? part_len : topk_val; - const int tail_count = topk_val - real_fill; - if (tail_count > 0) { - for (int i = tx; i < tail_count; i += blockDim.x) { - keys_out[real_fill + i] = -CUDART_INF_F; - idx_out [real_fill + i] = -1; + // Fast path: no threshold search needed when valid_count ≤ K. + const int valid_count = s_hist[0]; + if (valid_count <= K) { + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take) out_idx[slot] = s_indices[i]; } - __syncthreads(); + return; } - const ScoreT* __restrict__ slice_ptr = score + start + part_start; - - // ---- Phase 1: per-partition top-K --------------------------------------- - if (part_len > 0) { - if (part_len <= topk_val) { - // Whole slice fits under topk_val — emit it directly. - for (int i = tx; i < part_len; i += blockDim.x) { - const float raw = vortex_to_float(slice_ptr[i]); - const float remapped = apply_transform_tmpl(raw, mapping.power_exp); - keys_out[i] = remapped; - idx_out [i] = part_start + i; - } - } else { - fast_topk_partition( - slice_ptr, s_indices, f_input_idx_raw, s_bins, - 0, part_len, topk_val, mapping); - __syncthreads(); - for (int i = tx; i < topk_val; i += blockDim.x) { - const int sl = s_indices[i]; - const float raw = vortex_to_float(slice_ptr[sl]); - const float remapped = apply_transform_tmpl(raw, mapping.power_exp); - keys_out[i] = remapped; - idx_out [i] = part_start + sl; - } - } + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; - // Publish workspace writes so the last-CTA can observe them. - __threadfence(); + // (3) Emit above threshold via warp-popc; build sub-bin histogram on + // bits [23:16] for the tie-bin refinement. + if (tx < RADIX + 1) s_hist[tx] = 0; __syncthreads(); - // ---- Arrive at the barrier via atomicInc -------------------------------- - // atomicInc(ptr, N-1) stores `((old >= N-1) ? 0 : old+1)` and returns old. - // So with N == num_splits the counter cycles 0→1→…→N-1→0 per call, which - // means we never need to memset done_counter between calls — after the - // last-CTA's increment it's back at 0, ready for the next launch. - // (Relies on the caller allocating done_counter zero-initialised once.) - if (tx == 0) { - const unsigned int old = ::atomicInc( - reinterpret_cast(&g_parallel_done_counter[bx]), - static_cast(num_splits - 1)); - s_is_last = (old == static_cast(num_splits - 1)) ? 1 : 0; + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t idx = s_indices[i]; + if (idx >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + out_idx[slot] = s_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } } __syncthreads(); - if (s_is_last == 0) return; - - // ---- Merge: last CTA selects final top-K -------------------------------- - const int candidate_len = num_splits * topk_val; - const size_t batch_base = static_cast(bx) * candidate_len; - const float* keys_blk = partial_keys + batch_base; - const int* idx_blk = partial_idx + batch_base; - int* out_blk = sparse_kv_indices - + sparse_kv_indptr[bx] - + page_reserved_bos; - const int* dense_blk = dense_kv_indices + start; - - fast_topk_merge( - keys_blk, idx_blk, s_indices, f_input_idx_raw, - 0, candidate_len, topk_val); + // (4) Refinement cumsum → sub-threshold bin. + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie-bin refinement needed + } __syncthreads(); - - for (int i = tx; i < topk_val; i += blockDim.x) { - const int pos = s_indices[i]; - const int batch_local = idx_blk[pos]; - out_blk[i] = (batch_local >= 0) ? dense_blk[batch_local] : -1; + const int sub_threshold_bin_m = s_sub_threshold_bin; + + // (5) Emit tie-bin items via warp-popc + sub-threshold budget. + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_threshold = false; + int sub_bin = -1; + if (i < total) { + const int32_t idx = s_indices[i]; + if (idx >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_threshold = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take_sub_above = in_threshold && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + out_idx[slot] = s_indices[i]; + } else if (in_threshold && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) out_idx[K - pos] = s_indices[i]; + } } } -// ---- setup_kernel_smem_once (duplicated) ----------------------------------- +// ---- setup_kernel_smem_once ------------------------------------------------ template void setup_kernel_smem_once() { [[maybe_unused]] static const auto result = [] { -#ifdef USE_ROCM - return ::cudaFuncSetAttribute( - reinterpret_cast(f), - ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#else return ::cudaFuncSetAttribute( f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); -#endif }(); TORCH_CHECK(result == cudaSuccess, - "set_up_kernel_once (parallel) failed:", ::cudaGetErrorString(result)); + "fast_fused_topk_merge setup failed: ", + ::cudaGetErrorString(result)); } } // namespace @@ -662,150 +501,139 @@ void setup_kernel_smem_once() { // ============================================================================ // Host entry point. // -// Signature matches topk_output_sglang_fused plus `num_splits`. -// `num_splits <= 1` delegates to the single-CTA fused kernel so callers -// can unconditionally use this path. +// score [batch_size, num_chunks, chunk_size] bf16 or f32 +// global_topk_indices [batch_size, topk_val] int32 (output) +// +// ONE kernel launch. The per-chunk selection (Phase 1) and the +// cross-chunk merge (Phase 2) are fused in TopK_Parallel_Kernel via a +// last-CTA-wins atomicInc barrier. A per-call workspace holds the +// [batch, N, K] partial top-K that the last CTA reads from; the +// done-counter is a program-lifetime __device__ global so nothing +// needs memsetting on the hot path. // ============================================================================ -void topk_output_sglang_parallel( - const at::Tensor& x, - const at::Tensor& dense_kv_indptr, - const at::Tensor& sparse_kv_indptr, - const at::Tensor& dense_kv_indices, - at::Tensor& sparse_kv_indices, - const int64_t eff_batch_size, +void fast_fused_topk_merge( + const at::Tensor& score, + at::Tensor& global_topk_indices, + const int64_t batch_size, + const int64_t num_chunks, + const int64_t chunk_size, const int64_t topk_val, - const int64_t reserved_bos, - const int64_t reserved_eos, - const int64_t max_num_pages, - const int64_t num_splits, const int64_t mapping_mode, - const double mapping_power, - std::optional mapping_lut, - std::optional mapping_quantiles) + const double mapping_power) { - TORCH_CHECK(topk_val <= VORTEX_MAX_TOPK, - "topk_output_sglang_parallel: topk_val (", topk_val, - ") exceeds VORTEX_MAX_TOPK (", VORTEX_MAX_TOPK, ")"); - TORCH_CHECK(num_splits >= 1, - "topk_output_sglang_parallel: num_splits must be >= 1"); - - if (num_splits <= 1) { - topk_output_sglang_fused( - x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, - sparse_kv_indices, eff_batch_size, topk_val, - reserved_bos, reserved_eos, max_num_pages, - mapping_mode, mapping_power, mapping_lut, mapping_quantiles); - return; - } - - CHECK_CUDA(x); - CHECK_CUDA(dense_kv_indptr); - CHECK_CUDA(sparse_kv_indptr); - CHECK_CUDA(dense_kv_indices); - CHECK_CUDA(sparse_kv_indices); - - (void)mapping_lut; - (void)mapping_quantiles; - - TopKMappingParams mapping{}; - mapping.mode = static_cast(mapping_mode); - mapping.power_exp = static_cast(mapping_power); - mapping.lut = nullptr; - mapping.quantiles = nullptr; - - // Dynamic smem = kSmem (f_input_idx) + chunk_bytes (s_bins for the - // partition radix; the merge doesn't touch s_bins). - const int64_t chunk_pages = (max_num_pages + num_splits - 1) / num_splits; - const size_t chunk_bytes = (static_cast(chunk_pages) + size_t(15)) & ~size_t(15); - const size_t smem_bytes = kSmem + chunk_bytes; - TORCH_CHECK(smem_bytes <= kFusedSmemMax, - "topk_output_sglang_parallel: smem ", smem_bytes, - " exceeds ceiling ", kFusedSmemMax); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - TORCH_CHECK(eff_batch_size <= kMaxParallelEffBs, - "topk_output_sglang_parallel: eff_batch_size (", eff_batch_size, - ") exceeds kMaxParallelEffBs (", kMaxParallelEffBs, - "). Raise the __device__ counter array size."); - - // Per-call workspace. at::empty, no zero-init — kernel fills every used - // slot (valid prefix + sentinel tail). done_counter is a __device__ - // global (above) so no workspace allocation needed for it. - const int64_t ws_elems = eff_batch_size * num_splits * topk_val; - auto opts_f32 = at::TensorOptions().device(x.device()).dtype(at::kFloat); - auto opts_i32 = at::TensorOptions().device(x.device()).dtype(at::kInt); - at::Tensor partial_keys = at::empty({ws_elems}, opts_f32); - at::Tensor partial_idx = at::empty({ws_elems}, opts_i32); - - dim3 grid(static_cast(num_splits), - static_cast(eff_batch_size)); - dim3 nthreads(kThreadsPerBlock); - - #define VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MODE_VAL) \ - do { \ - setup_kernel_smem_once< \ - TopKOutput_Parallel_Kernel, \ - kFusedSmemMax>(); \ - TopKOutput_Parallel_Kernel \ - <<>>( \ - PTR_EXPR, \ - dense_kv_indptr.data_ptr(), \ - sparse_kv_indptr.data_ptr(), \ - dense_kv_indices.data_ptr(), \ - sparse_kv_indices.data_ptr(), \ - partial_keys.data_ptr(), \ - partial_idx.data_ptr(), \ - static_cast(topk_val), \ - static_cast(num_splits), \ - static_cast(reserved_bos), \ - static_cast(reserved_eos), \ - static_cast(chunk_bytes), \ - mapping); \ - } while (0) - - #define VORTEX_PARALLEL_DISPATCH_MODE(DTYPE, PTR_EXPR) \ - do { \ - switch (mapping.mode) { \ - case MAPPING_NONE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ - case MAPPING_POWER: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ - case MAPPING_LOG: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ - case MAPPING_ASINH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ - case MAPPING_LOG1P: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ - case MAPPING_TRUNC8: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_TRUNC8); break; \ - case MAPPING_ERF: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ - case MAPPING_TANH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ - case MAPPING_SUBTRACT: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ - case MAPPING_EXP_STRETCH: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ - case MAPPING_SHIFT_POW2: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ - case MAPPING_SHIFT_POW3: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ - case MAPPING_LINEAR_STEEP:VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ - case MAPPING_HALF_SQUARE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ - case MAPPING_HALF_CUBE: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ - case MAPPING_DENSE_MANT: VORTEX_PARALLEL_DISPATCH(DTYPE, PTR_EXPR, MAPPING_DENSE_MANT); break; \ - default: \ - TORCH_CHECK(false, \ - "topk_output_sglang_parallel: unsupported mapping_mode ", \ - mapping.mode); \ - } \ - } while (0) - - if (x.scalar_type() == at::ScalarType::BFloat16) { - VORTEX_PARALLEL_DISPATCH_MODE( - __nv_bfloat16, - reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); - } else if (x.scalar_type() == at::ScalarType::Float) { - VORTEX_PARALLEL_DISPATCH_MODE(float, x.data_ptr()); - } else { - TORCH_CHECK(false, "topk_output_sglang_parallel: unsupported dtype ", - x.scalar_type()); - } + CHECK_CUDA(score); + CHECK_CUDA(global_topk_indices); + + TORCH_CHECK(topk_val > 0 && topk_val <= VORTEX_MAX_TOPK, + "fast_fused_topk_merge: topk_val=", topk_val, + " must be in (0, ", VORTEX_MAX_TOPK, "]"); + TORCH_CHECK(num_chunks >= 1, "num_chunks must be >= 1"); + TORCH_CHECK(batch_size >= 1, "batch_size must be >= 1"); + TORCH_CHECK(batch_size <= kMaxBatch, + "fast_fused_topk_merge: batch_size ", batch_size, + " exceeds the __device__ done-counter cap (", kMaxBatch, ")"); + TORCH_CHECK(chunk_size >= 1, "chunk_size must be >= 1"); + TORCH_CHECK(num_chunks * topk_val <= kMergeCap, + "fast_fused_topk_merge: num_chunks*topk_val (", + num_chunks * topk_val, ") exceeds merge cap (", kMergeCap, + "). Reduce num_chunks or topk_val."); + TORCH_CHECK(global_topk_indices.scalar_type() == at::kInt, + "global_topk_indices must be int32"); + TORCH_CHECK(global_topk_indices.numel() >= batch_size * topk_val, + "global_topk_indices is too small for batch_size * topk_val"); + + TORCH_CHECK( + mapping_mode == MAPPING_POWER || + mapping_mode == MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP, + "fast_fused_topk_merge: mapping_mode=", mapping_mode, + " not supported. Valid: POWER(3), ASINH(6), LOG1P(7), ERF(9), " + "TANH(10), SUBTRACT(11), EXP_STRETCH(13), SHIFT_POW2(15), " + "SHIFT_POW3(16), LINEAR_STEEP(17)."); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + // Dynamic smem must fit whichever phase is larger: + // Phase 1: chunk_size floats + chunk_size bytes. + // Phase 2: num_chunks*topk_val * (float + int32). + const size_t p1_bytes = static_cast(chunk_size) * sizeof(float) + + ((static_cast(chunk_size) + 15) & ~size_t(15)); + const size_t p2_bytes = static_cast(num_chunks) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t smem_bytes = p1_bytes > p2_bytes ? p1_bytes : p2_bytes; + TORCH_CHECK(smem_bytes <= kMaxDynSmem, + "fast_fused_topk_merge: smem ", smem_bytes, + " > ceiling ", kMaxDynSmem); + + // Per-call workspace for the [batch, N, K] partial top-K. at::empty + // hits the caching allocator (no cudaMalloc in the hot path after + // warmup). The done-counter lives in __device__ memory — no memset. + auto opts_f32 = at::TensorOptions().device(score.device()).dtype(at::kFloat); + auto opts_i32 = at::TensorOptions().device(score.device()).dtype(at::kInt); + const int64_t ws_elems = batch_size * num_chunks * topk_val; + at::Tensor partial_keys = at::empty({ws_elems}, opts_f32); + at::Tensor partial_idx = at::empty({ws_elems}, opts_i32); + + dim3 grid(static_cast(batch_size), + static_cast(num_chunks)); + dim3 block(kThreadsPerBlock); + + #define LAUNCH(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, \ + kMaxDynSmem>(); \ + TopK_Parallel_Kernel \ + <<>>( \ + PTR_EXPR, \ + global_topk_indices.data_ptr(), \ + partial_keys.data_ptr(), \ + partial_idx.data_ptr(), \ + static_cast(num_chunks), \ + static_cast(chunk_size), \ + static_cast(topk_val), \ + mp); \ + } while (0) + + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping_mode) { \ + case MAPPING_POWER: LAUNCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_ASINH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + default: TORCH_CHECK(false, "unreachable mode"); \ + } \ + } while (0) + + if (score.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(score.data_ptr())); + } else if (score.scalar_type() == at::ScalarType::Float) { + DISPATCH_MODE(float, score.data_ptr()); + } else { + TORCH_CHECK(false, "fast_fused_topk_merge: unsupported dtype ", + score.scalar_type()); + } - #undef VORTEX_PARALLEL_DISPATCH_MODE - #undef VORTEX_PARALLEL_DISPATCH + #undef DISPATCH_MODE + #undef LAUNCH - const auto result = cudaGetLastError(); - TORCH_CHECK(result == cudaSuccess, - "topk_output_sglang_parallel kernel failed: ", - ::cudaGetErrorString(result)); + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "fast_fused_topk_merge kernel failed: ", ::cudaGetErrorString(rc)); } diff --git a/csrc/topk_sglang_profile.cu b/csrc/topk_sglang_profile.cu index 7fe99814..af6763a2 100644 --- a/csrc/topk_sglang_profile.cu +++ b/csrc/topk_sglang_profile.cu @@ -177,16 +177,9 @@ __device__ void fast_topk_profile( // Mirror of the production kernel: MAPPING_DENSE_MANT bypasses // apply_transform and uses a mantissa-heavy fp32 bit slice for the // Stage-1 bucket. - const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); - - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } + // MAPPING_DENSE_MANT / MAPPING_LUT_CDF / MAPPING_QUANTILE have been + // retired; every mode uses the standard fp16 bucket. + const bool use_dense_bucket = false; if (tx < RADIX + 1) p_histogram[tx] = 0; __syncthreads(); @@ -434,19 +427,11 @@ void TopKProfileHistogram_Kernel( const int end = dense_kv_indptr[bx + 1] - page_reserved_eos; const int nblk = end - start; - if (mapping.mode == MAPPING_LUT_CDF && mapping.lut != nullptr) { - if (tx < 256) s_mapping_lut[tx] = mapping.lut[tx]; - __syncthreads(); - } - if (mapping.mode == MAPPING_QUANTILE && mapping.quantiles != nullptr) { - if (tx < 256) s_mapping_quantiles[tx] = mapping.quantiles[tx]; - __syncthreads(); - } - if (tx < RADIX) s_histogram[tx] = 0; __syncthreads(); - const bool use_dense_bucket = (mapping.mode == MAPPING_DENSE_MANT); + // MAPPING_DENSE_MANT / MAPPING_LUT_CDF / MAPPING_QUANTILE retired. + const bool use_dense_bucket = false; if (nblk > 0) { const ScoreT* __restrict__ score_blk = score + start; for (int i = tx; i < nblk; i += BLOCK_SIZE) { diff --git a/examples/remap_function_bench_topk_parallel.sh b/examples/remap_function_bench_topk_parallel.sh index df33f2cf..4a4e4c57 100755 --- a/examples/remap_function_bench_topk_parallel.sh +++ b/examples/remap_function_bench_topk_parallel.sh @@ -26,7 +26,7 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" # ── Defaults ────────────────────────────────────────────────── -GPU_ID=4 +GPU_ID=7 MODEL_NAME="Qwen/Qwen3-1.7B" TOPK_VAL=2048 MEM=0.7 diff --git a/setup.py b/setup.py index 8b496610..9ff56088 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ 'csrc/topk_sglang_profile.cu', 'csrc/topk_sglang_ori.cu', 'csrc/topk_sglang_parallel.cu', + 'csrc/topk_sglang_cluster.cu', ], include_dirs=['csrc'], extra_compile_args={ From fb437447ff9f1b686af08f1c732530218df61e02 Mon Sep 17 00:00:00 2001 From: UED Date: Sun, 26 Apr 2026 16:00:03 -0400 Subject: [PATCH 24/24] Update TopK kernel implementations and benchmarking scripts - Replaced outdated TopK kernel files with new adaptive split and workspace ablation kernels. - Introduced new benchmarking scripts for comprehensive performance analysis of the adaptive split kernel. - Enhanced existing scripts to support new configurations and profiling metrics for the adaptive TopK implementation. - Updated setup.py to include new source files and removed deprecated kernel references. - Improved documentation and example scripts to reflect changes in benchmarking methodologies and configurations. --- benchmarks/bench_ablation.py | 341 +++ benchmarks/bench_midk_fused_baseline.py | 128 ++ benchmarks/bench_topk.py | 139 +- benchmarks/bench_topk_setting_sweep.py | 1211 ++++++++++ benchmarks/profile_adaptive_overhead.py | 212 ++ csrc/{ => archived}/topk_sglang_cluster.cu | 0 csrc/archived/topk_sglang_parallel.cu | 639 ++++++ csrc/register.cc | 65 +- csrc/register.h | 108 +- csrc/topk_adaptive_profile.cu | 1145 ++++++++++ csrc/topk_sglang_merge.cu | 1939 +++++++++++++++++ csrc/topk_sglang_parallel.cu | 639 ------ .../run_distribution_analysis.sh | 0 .../verify_algo_topk_mapping.sh | 0 examples/plot_parallel_comparison.py | 384 ++++ examples/profile_in_docker.sh | 181 -- examples/profile_parallel_vs_fused_ncu.sh | 277 --- examples/profile_parallel_vs_fused_nsys.sh | 211 -- .../remap_function_bench_topk_parallel.sh | 327 +-- examples/test_topk.py | 118 - setup.py | 4 +- 21 files changed, 6308 insertions(+), 1760 deletions(-) create mode 100644 benchmarks/bench_ablation.py create mode 100644 benchmarks/bench_midk_fused_baseline.py create mode 100644 benchmarks/bench_topk_setting_sweep.py create mode 100644 benchmarks/profile_adaptive_overhead.py rename csrc/{ => archived}/topk_sglang_cluster.cu (100%) create mode 100644 csrc/archived/topk_sglang_parallel.cu create mode 100644 csrc/topk_adaptive_profile.cu create mode 100644 csrc/topk_sglang_merge.cu delete mode 100644 csrc/topk_sglang_parallel.cu rename examples/{ => archived}/run_distribution_analysis.sh (100%) rename examples/{ => archived}/verify_algo_topk_mapping.sh (100%) create mode 100644 examples/plot_parallel_comparison.py delete mode 100755 examples/profile_in_docker.sh delete mode 100755 examples/profile_parallel_vs_fused_ncu.sh delete mode 100755 examples/profile_parallel_vs_fused_nsys.sh delete mode 100644 examples/test_topk.py diff --git a/benchmarks/bench_ablation.py b/benchmarks/bench_ablation.py new file mode 100644 index 00000000..53dbcbb4 --- /dev/null +++ b/benchmarks/bench_ablation.py @@ -0,0 +1,341 @@ +"""Phase + merge ablation for the K=30 random-split parallel kernel. + +Splits production latency into per-phase pieces and compares merge variants +on identical pre-filled workspaces. The fixture lives in +csrc/topk_adaptive_profile.cu (NOT in topk_sglang_merge.cu); the production +kernel still uses the SPLITS-specialised merge described in +csrc/topk_sglang_merge.cu's file header. + +Ablation modes (must match the kAblMode_* constants in topk_adaptive_profile.cu): + + 0 full_parallel (re-enters the production workspace API) + 1 local_only (Stage 1 sort + workspace write only) + 2 local_no_workspace (Stage 1 sort, scratch sink — no ws write) + 3 workspace_write_only (write 32 dummy entries / split) + 4 atomic_only (done_counter atomic + last-CTA test only) + 5 merge_prod_default (legacy per-SPLITS dispatch: 2-way/pairwise/k-way) + 6 merge_only_cub_warp (cub::WarpMergeSort — current production merge) + 7 merge_only_cub_block (cub::BlockMergeSort benchmark) + 8 memset_only (host cudaMemsetAsync of done_counter) + 9 merge_only_2way_manual (SPLITS=2 only) + 10 merge_only_pairwise_tree_4(SPLITS=4 only) + 11 merge_kway_all (force k-way for all SPLITS) + +Benchmark matrix (default; override on the CLI): + B ∈ {1, 2, 4, 8, 16, 32, 128} + pages ∈ {8192, 16384, 32768} + topk_val = 30 + partition = contiguous + forced_splits ∈ {2, 4, 8, 16, 32} + +Outputs `bench_results/k30_ablation.csv` (long-form per-row records) and a +wide table `…_summary.csv` with the columns the spec asks for: + B, pages, split, merge_mode, full_adaptive_us, local_only_us, + workspace_write_us, atomic_only_us, merge_only_us, fused_us, + speedup_vs_fused. +""" + +from __future__ import annotations + +import argparse +import csv +import statistics +import time +from collections import defaultdict +from pathlib import Path + +import torch +import vortex_torch_C as C + + +MODES = [ + (0, "full_parallel"), + (1, "local_only"), + (2, "local_no_workspace"), + (3, "workspace_write_only"), + (4, "atomic_only"), + (5, "merge_prod_default"), # legacy: 2-way/pairwise/k-way per SPLITS + (6, "merge_only_cub_warp"), # current production (WarpMergeSort) + (7, "merge_only_cub_block"), + (8, "memset_only"), + (9, "merge_only_2way_manual"), # SPLITS=2 only + (10, "merge_only_pairwise_tree_4"), # SPLITS=4 only + (11, "merge_kway_all"), # force k-way for all SPLITS +] + + +# ---------- input setup ------------------------------------------------------- + +def make_inputs(eff_bs: int, pages: int, topk_val: int = 30, + bos: int = 0, eos: int = 0, seed: int = 0, + dtype: torch.dtype = torch.bfloat16): + torch.manual_seed(seed) + device = "cuda" + x = torch.randn(eff_bs * pages, dtype=dtype, device=device) + dense_kv_indptr = torch.arange(eff_bs + 1, dtype=torch.int32, device=device) * pages + dense_kv_indices = torch.arange(eff_bs * pages, dtype=torch.int32, device=device) + out_per_row = bos + eos + topk_val + sparse_kv_indptr = torch.arange(eff_bs + 1, dtype=torch.int32, device=device) * out_per_row + sparse_kv_indices = torch.full((eff_bs * out_per_row,), -1, + dtype=torch.int32, device=device) + return { + "x": x, + "dense_kv_indptr": dense_kv_indptr, + "sparse_kv_indptr": sparse_kv_indptr, + "dense_kv_indices": dense_kv_indices, + "sparse_kv_indices": sparse_kv_indices, + } + + +def make_workspace(eff_bs: int): + opts = dict(dtype=torch.int32, device="cuda") + n = eff_bs * 32 * 32 # max splits=32, local_k=32 + return { + "partial_keys": torch.empty(n, **opts), + "partial_indices": torch.empty(n, **opts), + "done_counter": torch.empty(eff_bs, **opts), + "scratch": torch.empty(eff_bs * 32, **opts), + } + + +def fill_workspace_for_merge(ws, eff_bs, splits, seed=1): + """Pre-fill partial_keys/indices with sorted top-32 lists per split. + + Production layout: `[B, SPLITS, 32]` flattened to a 1-D int32 tensor. + Each (b, split) slot is sorted descending by uint32 key. Indices are + distinct global page IDs (no -1 sentinels in the prefilled portion). + """ + torch.manual_seed(seed) + n = eff_bs * splits * 32 + keys_base = torch.randint(0, 2**31 - 1, (eff_bs * splits, 32), + dtype=torch.int64, device="cuda").to(torch.int32) + keys_sorted = keys_base.sort(dim=1, descending=True).values + ws["partial_keys"][:n] = keys_sorted.flatten() + indices = torch.arange(n, dtype=torch.int32, device="cuda") + ws["partial_indices"][:n] = indices + + +# ---------- kernel calls ------------------------------------------------------ + +def call_ablation(inputs, ws, eff_bs, pages, topk_val, mode, splits, + bos=0, eos=0): + inputs["sparse_kv_indices"].fill_(-1) + C.topk_output_adaptive_workspace_ablation( + inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + ws["partial_keys"], ws["partial_indices"], + ws["done_counter"], ws["scratch"], + eff_bs, topk_val, bos, eos, pages, + mode, splits, + ) + + +def call_fused(inputs, eff_bs, pages, topk_val, bos=0, eos=0): + inputs["sparse_kv_indices"].fill_(-1) + C.topk_output_sglang_fused( + inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + eff_bs, topk_val, bos, eos, pages, 0, 0.0, None, None, + ) + + +def bench(fn, *args, warmup=20, repeat=200): + for _ in range(warmup): + fn(*args) + torch.cuda.synchronize() + samples = [] + for _ in range(repeat): + torch.cuda.synchronize() + t0 = time.perf_counter() + fn(*args) + torch.cuda.synchronize() + samples.append((time.perf_counter() - t0) * 1e3) # ms + samples.sort() + return { + "mean": statistics.mean(samples), + "p50": samples[len(samples) // 2], + "p90": samples[int(len(samples) * 0.9)], + "min": samples[0], + "max": samples[-1], + } + + +# ---------- correctness check ------------------------------------------------- + +def verify_merge_only(ws, eff_bs, splits, topk_val, bos=0): + """Check that the merge_only_prod_default kernel returns the true top-K. + + Builds a reference by reading partial_keys/indices into Python, picking + the largest topk_val keys per row, and comparing against the kernel's + output as a SET (production merge order is unspecified for ties). + """ + inputs = make_inputs(eff_bs, pages=8192, topk_val=topk_val, bos=bos) + fill_workspace_for_merge(ws, eff_bs, splits, seed=42) + inputs["sparse_kv_indices"].fill_(-1) + C.topk_output_adaptive_workspace_ablation( + inputs["x"], inputs["dense_kv_indptr"], inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], inputs["sparse_kv_indices"], + ws["partial_keys"], ws["partial_indices"], + ws["done_counter"], ws["scratch"], + eff_bs, topk_val, bos, 0, 8192, + 11, splits, # mode 11 = prod_default + ) + torch.cuda.synchronize() + + # Reference top-K from the prefilled workspace. + n = eff_bs * splits * 32 + keys = ws["partial_keys"][:n].view(eff_bs, splits * 32).to(torch.int64) & 0xFFFFFFFF + idx = ws["partial_indices"][:n].view(eff_bs, splits * 32) + out_per_row = bos + topk_val + out = inputs["sparse_kv_indices"] + + failures = 0 + for b in range(eff_bs): + ref_topk = keys[b].topk(topk_val).indices # local positions + ref_set = set(idx[b, ref_topk].tolist()) + got = out[b * out_per_row + bos : b * out_per_row + bos + topk_val] + got_set = set(got.tolist()) - {-1} + if ref_set != got_set: + failures += 1 + if failures <= 3: + print(f" MERGE CORRECTNESS FAIL b={b} splits={splits} K={topk_val}: " + f"|sym_diff|={len(ref_set ^ got_set)}") + return failures == 0 + + +# ---------- main -------------------------------------------------------------- + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--out", default="bench_results/k30_ablation.csv") + ap.add_argument("--summary-out", default="bench_results/k30_ablation_summary.csv") + ap.add_argument("--warmup", type=int, default=20) + ap.add_argument("--repeat", type=int, default=200) + ap.add_argument("--pages", type=int, nargs="+", default=[8192, 16384, 32768]) + ap.add_argument("--bs", type=int, nargs="+", default=[1, 2, 4, 8, 16, 32, 128]) + ap.add_argument("--splits", type=int, nargs="+", default=[2, 4, 8, 16, 32]) + ap.add_argument("--skip-correctness", action="store_true") + args = ap.parse_args() + + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + + # -- correctness gate first -- + if not args.skip_correctness: + print("=== Merge correctness check (mode=11 merge_kway_all) ===") + ws_check = make_workspace(max(args.bs)) + all_ok = True + for splits in (2, 4, 8, 16, 32): + for K in (1, 4, 8, 16, 30, 32): + ok = verify_merge_only(ws_check, eff_bs=4, splits=splits, topk_val=K) + tag = "OK " if ok else "FAIL" + print(f" splits={splits:2d} K={K:2d} : {tag}") + all_ok &= ok + # reserved_bos cover + ok = verify_merge_only(ws_check, eff_bs=4, splits=splits, topk_val=30, bos=2) + print(f" splits={splits:2d} K=30 bos=2: {'OK ' if ok else 'FAIL'}") + all_ok &= ok + if not all_ok: + print("CORRECTNESS FAILURES — aborting bench") + return 1 + print("All merge-only correctness checks passed.\n") + + long_rows = [] + # cell -> {ablation_name: mean_ms} + cells = defaultdict(dict) + + for pages in args.pages: + for B in args.bs: + inputs = make_inputs(B, pages) + ws = make_workspace(B) + + # Reference: fused. + call_fused(inputs, B, pages, 30) + torch.cuda.synchronize() + s_fused = bench(call_fused, inputs, B, pages, 30, + warmup=args.warmup, repeat=args.repeat) + long_rows.append({ + "pages": pages, "B": B, "splits": 0, + "ablation": "fused_baseline", + **{k: f"{v:.4f}" for k, v in s_fused.items()}, + }) + print(f"\n=== pages={pages} B={B} === fused = {s_fused['mean']*1000:.2f} us") + + for splits in args.splits: + # Pre-fill workspace ahead of the merge-only modes. + fill_workspace_for_merge(ws, B, splits) + torch.cuda.synchronize() + + for mode_id, mode_name in MODES: + if mode_id == 9 and splits != 2: continue + if mode_id == 10 and splits != 4: continue + # Re-prefill before merge-only calls so input layout is fresh. + if mode_id in (5, 6, 7, 9, 10, 11): + fill_workspace_for_merge(ws, B, splits) + torch.cuda.synchronize() + try: + call_ablation(inputs, ws, B, pages, 30, mode_id, splits) + torch.cuda.synchronize() + except RuntimeError as e: + print(f" split={splits} {mode_name}: SKIP ({e})") + continue + stats = bench(call_ablation, inputs, ws, B, pages, 30, + mode_id, splits, + warmup=args.warmup, repeat=args.repeat) + long_rows.append({ + "pages": pages, "B": B, "splits": splits, + "ablation": mode_name, + **{k: f"{v:.4f}" for k, v in stats.items()}, + }) + cells[(pages, B, splits)][mode_name] = stats["mean"] + cells[(pages, B, splits)]["__fused"] = s_fused["mean"] + pct = stats["mean"] / s_fused["mean"] * 100 + print(f" split={splits:2d} {mode_name:<28s}" + f" mean={stats['mean']*1000:7.2f} us" + f" ({pct:5.1f}% of fused)") + + # ---------- write long-form CSV ---------- + long_cols = ["pages", "B", "splits", "ablation", "mean", "p50", "p90", "min", "max"] + with open(out_path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=long_cols) + w.writeheader() + w.writerows(long_rows) + print(f"\nlong-form rows → {out_path} ({len(long_rows)} rows)") + + # ---------- write spec-shaped summary ---------- + summary_path = Path(args.summary_out) + summary_cols = ["B", "pages", "split", "merge_mode", + "full_adaptive_us", "local_only_us", "workspace_write_us", + "atomic_only_us", "merge_only_us", "fused_us", + "speedup_vs_fused"] + + def _us(ms): + return f"{ms * 1000:.2f}" if ms is not None else "" + + with open(summary_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(summary_cols) + for (pages, B, splits), data in sorted(cells.items()): + full = data.get("full_parallel") + local = data.get("local_only") + ws_write = data.get("workspace_write_only") + atomic = data.get("atomic_only") + fused = data.get("__fused") + for merge_name in ("merge_prod_default", "merge_only_cub_warp", + "merge_only_cub_block", "merge_kway_all", + "merge_only_2way_manual", + "merge_only_pairwise_tree_4"): + merge_t = data.get(merge_name) + if merge_t is None: continue + speedup = (fused / full) if (full and fused) else float("nan") + w.writerow([B, pages, splits, merge_name, + _us(full), _us(local), _us(ws_write), + _us(atomic), _us(merge_t), _us(fused), + f"{speedup:.3f}" if speedup == speedup else ""]) + print(f"summary table → {summary_path}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmarks/bench_midk_fused_baseline.py b/benchmarks/bench_midk_fused_baseline.py new file mode 100644 index 00000000..c49080d1 --- /dev/null +++ b/benchmarks/bench_midk_fused_baseline.py @@ -0,0 +1,128 @@ +"""Quick fused baseline measurement at mid-K (K in {64,128,256,512}). + +Goal: establish the bar that any adaptive split implementation has to beat +before we commit to building / templating SELECTK_SORTK kernels. + +Output: bench_results/midk_fused_baseline.csv + + a printed table per K. +""" +from __future__ import annotations + +import argparse +import csv +import math +import os +from pathlib import Path + +import torch +import vortex_torch_C as V + + +def time_kernel_us(fn, warmup=10, repeat=100): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + starts = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + for i in range(repeat): + starts[i].record(); fn(); ends[i].record() + torch.cuda.synchronize() + times = sorted(starts[i].elapsed_time(ends[i]) * 1000.0 for i in range(repeat)) + n = len(times) + mean = sum(times) / n + var = sum((t - mean) ** 2 for t in times) / n + return dict(mean=mean, p50=times[n // 2], p90=times[min(n - 1, int(round(n * 0.9)))], + min=times[0], max=times[-1], std=math.sqrt(var)) + + +def make_inputs(B, pages, K, dtype=torch.bfloat16, reserved_bos=1, reserved_eos=2): + device = torch.device("cuda") + dense_kv_indptr = torch.arange(B + 1, device=device, dtype=torch.int32) * pages + sparse_kv_indptr = torch.arange(B + 1, device=device, dtype=torch.int32) * (K + reserved_bos + reserved_eos) + total = B * pages + torch.manual_seed(0) + scores = torch.randn(total, device=device, dtype=dtype) + dense_kv_indices = torch.arange(total, device=device, dtype=torch.int32) + out = torch.full((B * (K + reserved_bos + reserved_eos),), -1, device=device, dtype=torch.int32) + return scores, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, out + + +def call_fused(scores, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, out, + B, K, reserved_bos, reserved_eos, pages, mapping_mode): + V.topk_output_sglang_fused( + scores, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, out, + B, K, reserved_bos, reserved_eos, pages, + mapping_mode, 0.5, None, None, + ) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--out", default="bench_results/midk_fused_baseline.csv") + ap.add_argument("--pages", nargs="+", type=int, default=[16384, 32768, 65536, 131072]) + ap.add_argument("--ks", nargs="+", type=int, default=[64, 128, 256, 512]) + ap.add_argument("--batches", nargs="+", type=int, default=[1, 2, 4, 8, 16]) + ap.add_argument("--mappings", nargs="+", type=int, default=[0, 8]) # NONE, TRUNC8 + ap.add_argument("--warmup", type=int, default=10) + ap.add_argument("--repeat", type=int, default=100) + args = ap.parse_args() + + out_path = Path(args.out) + out_path.parent.mkdir(parents=True, exist_ok=True) + repo_root = Path(__file__).resolve().parents[1] + if not out_path.is_absolute(): + out_path = repo_root / out_path + + device = torch.cuda.get_device_properties(0) + print(f"# GPU: {device.name}, SMs={device.multi_processor_count}") + print(f"# pages: {args.pages}") + print(f"# Ks: {args.ks}") + print(f"# Bs: {args.batches}") + print(f"# maps: {args.mappings} (0=NONE, 8=TRUNC8)") + print() + + rows = [] + for K in args.ks: + print(f"=== K={K} ===") + print(f"{'pages':>8s} {'B':>3s} {'map':>5s} {'mean_us':>10s} {'p50_us':>10s} " + f"{'min_us':>10s} {'std_us':>8s} status") + for pages in args.pages: + for B in args.batches: + for mapping in args.mappings: + map_name = {0: "NONE", 8: "TRUNC8"}.get(mapping, str(mapping)) + try: + ins = make_inputs(B, pages, K) + # warmup correctness check + call_fused(*ins, B, K, 1, 2, pages, mapping) + torch.cuda.synchronize() + except Exception as e: + print(f"{pages:>8d} {B:>3d} {map_name:>5s} {'-':>10s} {'-':>10s} " + f"{'-':>10s} {'-':>8s} FAILED: {str(e)[:80]}") + rows.append(dict(K=K, pages=pages, B=B, mapping=map_name, + mean_us=None, p50_us=None, p90_us=None, + min_us=None, max_us=None, std_us=None, + status="failed", error=str(e)[:200])) + continue + t = time_kernel_us( + lambda: call_fused(*ins, B, K, 1, 2, pages, mapping), + warmup=args.warmup, repeat=args.repeat, + ) + print(f"{pages:>8d} {B:>3d} {map_name:>5s} {t['mean']:>10.3f} " + f"{t['p50']:>10.3f} {t['min']:>10.3f} {t['std']:>8.3f} ok") + rows.append(dict(K=K, pages=pages, B=B, mapping=map_name, + mean_us=t['mean'], p50_us=t['p50'], p90_us=t['p90'], + min_us=t['min'], max_us=t['max'], std_us=t['std'], + status="ok", error="")) + del ins + print() + + with open(out_path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=list(rows[0].keys())) + w.writeheader() + for r in rows: + w.writerow(r) + print(f"# wrote {out_path}") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_topk.py b/benchmarks/bench_topk.py index 8198d0c2..68a4c956 100644 --- a/benchmarks/bench_topk.py +++ b/benchmarks/bench_topk.py @@ -27,8 +27,7 @@ topk_output_sglang, # 2-stage radix approximate topk (unmapped baseline) topk_output_sglang_fused, # fused remap + 2-stage radix topk topk_output_sglang_ori, # original SGLang reference kernel - fast_fused_topk_merge, # single-kernel split+merge (new parallel kernel) - fast_cluster_topk_merge, # Hopper TBC+DSMEM fused split+merge (sm_90+) + topk_output_adaptive, # adaptive split-2 last-CTA-wins (hybrid radix/CUB) topk_remap_only, # standalone value-space remap topk_profile_histogram, topk_profile_counters, @@ -571,6 +570,28 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, none_stats = _collect_threshold_stats( inputs, topk_val, pages_per_seg, args, mode=0, power=0.5 ) + + # Adaptive split-2 kernel (last-CTA-wins merge). Enabled via --bench-parallel. + # For the None row we run it with mapping_mode=0 (identity transform). + none_parallel_ms = None + none_parallel_splits = None + none_cluster_ms = None + if getattr(args, "bench_parallel", False): + par_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + 0, # mapping_mode = NONE + 0.5, # mapping_power (unused for mode 0) + ) + inputs["sparse_kv_indices"].zero_() + par_none = bench_kernel(topk_output_adaptive, par_args, args.warmup, args.repeat) + none_parallel_ms = par_none["mean_ms"] + none_parallel_splits = 2 + config["modes"].append({ "mode": 0, "mode_name": "None", @@ -579,8 +600,9 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, "topk_after_remap_ms": baseline["mean_ms"], "split_total_ms": None, "fused_ms": None, - "parallel_ms": None, - "parallel_splits": None, + "parallel_ms": none_parallel_ms, + "parallel_splits": none_parallel_splits, + "cluster_ms": none_cluster_ms, **none_stats, }) @@ -637,100 +659,28 @@ def _remap_bench_one_config(args, batch_size, num_kv_heads, seq_len, topk_val, inputs["sparse_kv_indices"].zero_() fused = bench_kernel(topk_output_sglang_fused, fused_args, args.warmup, args.repeat) - # New single-kernel split+merge (fast_fused_topk_merge). Takes a - # dense [B, N, chunk] score tensor and writes [B, K] int32 - # indices. We reshape the bench's [eff_bs * pages_per_seg] flat - # scores into [eff_bs, num_splits, chunk_per_split] and compare - # the resulting batch-local indices against the fused kernel's - # (which map through the identity dense_kv_indices). + # Adaptive split-2 kernel (last-CTA-wins merge) with remap mode. + # Only ARITHMETIC_MODES are supported by topk_output_adaptive — the + # LUT/quantile/trunc8 modes have no apply_transform arithmetic path. parallel_ms = None parallel_splits_used = None cluster_ms = None - if getattr(args, "bench_parallel", False) and mode in { - 3, 6, 7, 9, 10, 11, 13, 15, 16, 17 - }: - splits = getattr(args, "num_splits", -1) - if splits is None or splits < 1: - splits = _auto_num_splits(eff_bs, pages_per_seg, topk_val) - # num_splits must divide pages_per_seg, and num_splits*topk_val - # must fit the merge cap (8192). Clamp + round. - if splits > 1 and pages_per_seg % splits != 0: - # snap down to the largest divisor ≤ splits - for cand in range(splits, 0, -1): - if pages_per_seg % cand == 0: - splits = cand - break - while splits * topk_val > 8192 and splits > 1: - splits //= 2 - # splits=1 means "no parallel" — the parallel kernel has no - # work to do and would ask for seq_len * 5 bytes of smem (the - # Phase-1 cache), blowing past the 96 KB ceiling at seq_len - # > ~19K. Skip the parallel timing row for this config. - if splits < 2: - parallel_ms = None - parallel_splits_used = None - row = { - "mode": mode, - "mode_name": MAPPING_MODE_NAMES.get(mode, f"m{mode}"), - "power": power, - "remap_ms": None, - "topk_after_remap_ms": None, - "split_total_ms": None, - "fused_ms": fused["mean_ms"], - "parallel_ms": parallel_ms, - "parallel_splits": parallel_splits_used, - "cluster_ms": cluster_ms, - **_collect_threshold_stats( - inputs, topk_val, pages_per_seg, args, mode, power - ), - } - config["modes"].append(row) - continue - chunk_per_split = pages_per_seg // splits - parallel_x = ( - inputs["x"].view(eff_bs, pages_per_seg) - .view(eff_bs, splits, chunk_per_split) - .contiguous() - ) - parallel_out = torch.empty(eff_bs, topk_val, - dtype=torch.int32, device="cuda") - parallel_args = ( - parallel_x, - parallel_out, - eff_bs, - splits, - chunk_per_split, - topk_val, - mode, - power, + if getattr(args, "bench_parallel", False) and mode in ARITHMETIC_MODES: + par_args = ( + inputs["x"], + inputs["dense_kv_indptr"], + inputs["sparse_kv_indptr"], + inputs["dense_kv_indices"], + inputs["sparse_kv_indices"], + eff_bs, topk_val, args.reserved_bos, args.reserved_eos, pages_per_seg, + mode, power, ) - parallel = bench_kernel( - fast_fused_topk_merge, parallel_args, args.warmup, args.repeat + inputs["sparse_kv_indices"].zero_() + par_bench = bench_kernel( + topk_output_adaptive, par_args, args.warmup, args.repeat ) - parallel_ms = parallel["mean_ms"] - parallel_splits_used = splits - - # Hopper TBC+DSMEM variant — same args, sm_90+ only, - # cluster cap = 8. Fresh output buffer so validation can - # compare against the parallel kernel's output independently. - cluster_ms = None - if splits <= 8 and torch.cuda.get_device_capability(0)[0] >= 9: - cluster_out = torch.empty(eff_bs, topk_val, - dtype=torch.int32, device="cuda") - cluster_args = ( - parallel_x, - cluster_out, - eff_bs, - splits, - chunk_per_split, - topk_val, - mode, - power, - ) - cluster = bench_kernel( - fast_cluster_topk_merge, cluster_args, args.warmup, args.repeat - ) - cluster_ms = cluster["mean_ms"] + parallel_ms = par_bench["mean_ms"] + parallel_splits_used = 2 # Split-phase timing is only meaningful for arithmetic modes. # MAPPING_LUT_CDF / QUANTILE / TRUNC8 apply their mapping inside @@ -1134,7 +1084,8 @@ def main(): p.add_argument("--remap-bench", action="store_true", help="Run the split-phase remap/topk/fused/baseline benchmark.") p.add_argument("--bench-parallel", action="store_true", - help="Also time fast_fused_topk_merge (single-kernel split+merge).") + help="Time the adaptive split-2 last-CTA-wins kernel " + "(topk_output_adaptive) and fill the parallel_ms column.") p.add_argument("--num-splits", type=int, default=-1, help="Partitions per batch for the parallel kernel. -1 = auto " "(sm_count / eff_batch_size, clamped to pages_per_seg/topk_val).") diff --git a/benchmarks/bench_topk_setting_sweep.py b/benchmarks/bench_topk_setting_sweep.py new file mode 100644 index 00000000..db587b74 --- /dev/null +++ b/benchmarks/bench_topk_setting_sweep.py @@ -0,0 +1,1211 @@ +#!/usr/bin/env python +"""Comprehensive (pages, K, batch, split, mapping, dtype) latency sweep +comparing the three TopK kernels in this repo: + + topk_sglang_merge.cu -> topk_output_adaptive_workspace (adaptive split path) + topk_sglang.cu -> topk_output_sglang_fused (fused two-stage radix) + topk.cu -> topk_output (CUB BlockRadixSort full sort) + +Outputs four files under : + topk_setting_sweep_raw.csv long-form, one row per measurement + topk_setting_sweep_best_adaptive.csv best adaptive split per (pages,K,B,mapping,dtype) + topk_parallel_advantage_summary.csv win/loss region rollup + topk_setting_sweep_report.md human-readable analysis + +See module-level docstring of topk_sglang_merge.cu for the dispatcher +contract this script mirrors when labeling actual_path. +""" +from __future__ import annotations + +import argparse +import csv +import math +import statistics +import sys +import time +from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Dict, List, Optional, Tuple + +import torch +import vortex_torch_C as V + +# --------------------------------------------------------------------------- # +# Mapping mode constants — must match csrc/topk_mapping.cuh. +# --------------------------------------------------------------------------- # +MAPPING_NONE = 0 +MAPPING_POWER = 3 +MAPPING_LOG = 4 +MAPPING_ASINH = 6 +MAPPING_LOG1P = 7 +MAPPING_TRUNC8 = 8 +MAPPING_ERF = 9 +MAPPING_TANH = 10 + +MAPPING_NAMES = { + MAPPING_NONE: "NONE", + MAPPING_POWER: "POWER", + MAPPING_LOG: "LOG", + MAPPING_ASINH: "ASINH", + MAPPING_LOG1P: "LOG1P", + MAPPING_TRUNC8: "TRUNC8", + MAPPING_ERF: "ERF", + MAPPING_TANH: "TANH", +} +MAPPING_BY_NAME = {v: k for k, v in MAPPING_NAMES.items()} + +# Mirrors the dispatcher in csrc/topk_sglang_merge.cu. +K_MAX_ADAPTIVE = 32 # K <= 32 stays on the adaptive K=30 path +K_FUSED_FALLBACK = 1024 # K >= 1024 routes to fused, even from adaptive entry + +LOCAL_BLOCK_FULL_SORT = 0 +LOCAL_SELECT32_SORT32 = 1 + +# topk.cu template ladder caps at 8192 pages. +TOPK_CU_MAX_PAGES = 8192 + +# kCfg* capacity table from csrc/topk_sglang_merge.cu (BLOCK_FULL_SORT only). +BLOCK_FULL_SORT_CAPACITY = {1: 8192, 2: 8192, 4: 4096, 8: 4096, 16: 2048, 32: 1024} + +DEFAULT_PAGES = [4096, 8192, 16384, 32768, 65536] +DEFAULT_KS = [30, 64, 128, 256, 512, 1024, 2048] +DEFAULT_BATCHES = [1, 2, 4, 8, 16] +DEFAULT_SPLITS = [1, 2, 4, 8, 16, 32] +DEFAULT_MAPPING_NAMES = ["NONE"] +DEFAULT_DTYPES = ["bfloat16"] + +WIN_THRESHOLD = 1.03 # adaptive "wins" if speedup_vs_sglang >= this + +# Production merge mode wired into TopK30_RandomSplit_Select32_Kernel / +# TopK30_RandomSplit_Parallel_Kernel — cub::WarpMergeSort. The merge-only +# ablation sub-sweep timings are written separately to topk_merge_mode_summary.csv. +PROD_MERGE_NAME = "warp_cub" +LOCAL_MODE_NAMES = {LOCAL_BLOCK_FULL_SORT: "BLOCK_FULL_SORT", + LOCAL_SELECT32_SORT32: "SELECT32_SORT32"} + +# Ablation mode IDs from topk_adaptive_profile.cu. +ABL_LOCAL_WITH_WORKSPACE = 1 # populates partial workspace, no merge +ABL_MERGE_PROD_DEFAULT = 5 +ABL_MERGE_CUB_WARP = 6 +ABL_MERGE_CUB_BLOCK = 7 +ABL_MERGE_KWAY = 11 +MERGE_ABL_NAMES = { + ABL_MERGE_PROD_DEFAULT: "prod_default(legacy)", + ABL_MERGE_CUB_WARP: "warp_cub", + ABL_MERGE_CUB_BLOCK: "block_cub", + ABL_MERGE_KWAY: "kway", +} + +# --------------------------------------------------------------------------- # + +def _dtype_str_to_torch(s: str) -> torch.dtype: + return {"bfloat16": torch.bfloat16, "float": torch.float32, "float32": torch.float32}[s] + + +def apply_remap_torch(x: torch.Tensor, mode: int, p: float) -> torch.Tensor: + """Reference-side remap, kept in sync with apply_transform_tmpl in topk_mapping.cuh. + + Only modes used by this sweep are implemented. Adding more requires editing + topk_mapping.cuh and propagating to this function. + """ + if mode in (MAPPING_NONE, MAPPING_TRUNC8): + return x + if mode == MAPPING_POWER: + return torch.copysign(torch.abs(x).pow(p), x) + if mode == MAPPING_LOG: + return torch.copysign(torch.log(torch.abs(x) + 1.0), x) + if mode == MAPPING_ASINH: + return torch.asinh(p * x) + if mode == MAPPING_LOG1P: + return torch.copysign(torch.log1p(p * torch.abs(x)), x) + if mode == MAPPING_ERF: + return torch.erf(p * x) + if mode == MAPPING_TANH: + return torch.tanh(p * x) + raise ValueError(f"reference remap not implemented for mapping_mode={mode}") + + +# --------------------------------------------------------------------------- # +# Tensor / workspace setup. +# --------------------------------------------------------------------------- # +@dataclass +class Inputs: + scores: torch.Tensor + dense_kv_indptr: torch.Tensor + sparse_kv_indptr: torch.Tensor + dense_kv_indices: torch.Tensor + out: torch.Tensor # int32 sparse_kv_indices + B: int + pages: int + K: int + reserved_bos: int + reserved_eos: int + + +def make_inputs(B: int, pages: int, K: int, dtype: torch.dtype, + reserved_bos: int = 1, reserved_eos: int = 2, + seed: int = 0) -> Inputs: + torch.manual_seed(seed) + device = torch.device("cuda") + dense_kv_indptr = torch.arange(B + 1, device=device, dtype=torch.int32) * pages + sparse_kv_indptr = torch.arange(B + 1, device=device, dtype=torch.int32) * (K + reserved_bos + reserved_eos) + total = B * pages + scores = torch.randn(total, device=device, dtype=dtype) + dense_kv_indices = torch.arange(total, device=device, dtype=torch.int32) + out = torch.full((B * (K + reserved_bos + reserved_eos),), -1, device=device, dtype=torch.int32) + return Inputs(scores, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, out, + B, pages, K, reserved_bos, reserved_eos) + + +def make_workspace(B_max: int, max_split: int = 32, K_local: int = 32): + device = torch.device("cuda") + ws_elems = max(B_max * max_split * K_local, 64) + return dict( + partial_keys = torch.zeros(ws_elems, device=device, dtype=torch.int32), + partial_indices = torch.zeros(ws_elems, device=device, dtype=torch.int32), + done_counter = torch.zeros(max(B_max, 1), device=device, dtype=torch.int32), + ) + + +# --------------------------------------------------------------------------- # +# Reference top-K and correctness. +# --------------------------------------------------------------------------- # +def reference_topk(inp: Inputs, mapping_mode: int, mapping_power: float): + """Returns (ref_sets[B], ref_remapped[B] (cpu fp32), threshold_per_row[B]).""" + ref_sets = [] + ref_remapped = [] + thresholds = [] + for b in range(inp.B): + row = inp.scores[b * inp.pages + inp.reserved_bos + : (b + 1) * inp.pages - inp.reserved_eos].float() + remapped = apply_remap_torch(row, mapping_mode, mapping_power) + vals, idx_within = torch.topk(remapped, inp.K) + global_idx = (idx_within + b * inp.pages + inp.reserved_bos).cpu().tolist() + ref_sets.append(set(global_idx)) + ref_remapped.append(remapped.cpu()) + thresholds.append(vals.min().item()) + return ref_sets, ref_remapped, thresholds + + +def check_correctness(inp: Inputs, ref_sets, ref_remapped, thresholds, + mapping_mode: int, mapping_power: float) -> Tuple[bool, str]: + """Set equality with tie tolerance. + + Returns (ok, note). On failure, `note` describes the failure. + """ + out = inp.out.cpu() + for b in range(inp.B): + slot_start = b * (inp.K + inp.reserved_bos + inp.reserved_eos) + inp.reserved_bos + out_row = out[slot_start : slot_start + inp.K].tolist() + out_set = set(out_row) + if -1 in out_set: + return False, f"row {b}: -1 in output (count={out_row.count(-1)})" + if out_set == ref_sets[b]: + continue + # Tie tolerance: every kernel-selected score must reach the threshold, + # within fp tolerance. + row_offset = b * inp.pages + out_within = [g - row_offset - inp.reserved_bos for g in out_set] + npages_eff = inp.pages - inp.reserved_bos - inp.reserved_eos + if any(i < 0 or i >= npages_eff for i in out_within): + return False, f"row {b}: out-of-range global idx" + out_scores = ref_remapped[b][out_within] + thresh = thresholds[b] + tol = max(1e-6, 1e-3 * abs(thresh)) + min_out = out_scores.min().item() + if min_out < thresh - tol: + return False, (f"row {b}: min selected score={min_out:.4f} < " + f"K-th ref score={thresh:.4f} (tol={tol:.2e})") + return True, "" + + +# --------------------------------------------------------------------------- # +# Timing. +# --------------------------------------------------------------------------- # +def time_kernel_us(fn, warmup: int, repeat: int) -> Optional[Dict[str, float]]: + """Per-call event timing. Returns dict with mean/p50/p90/min/max/std (us) + or None if the kernel raised.""" + try: + for _ in range(warmup): + fn() + torch.cuda.synchronize() + except Exception: + return None + starts = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + try: + for i in range(repeat): + starts[i].record() + fn() + ends[i].record() + torch.cuda.synchronize() + except Exception: + return None + times_us = sorted(starts[i].elapsed_time(ends[i]) * 1000.0 for i in range(repeat)) + n = len(times_us) + mean = sum(times_us) / n + var = sum((t - mean) ** 2 for t in times_us) / n + return dict( + mean=mean, + p50=times_us[n // 2], + p90=times_us[min(n - 1, int(round(n * 0.9)))], + min=times_us[0], + max=times_us[-1], + std=math.sqrt(var), + ) + + +# --------------------------------------------------------------------------- # +# Method launchers. +# --------------------------------------------------------------------------- # +def call_fused(inp: Inputs, mapping_mode: int, mapping_power: float): + # Caller is responsible for inp.out.fill_(-1) BEFORE the timed loop if it + # cares about a clean baseline; the fill is its own kernel launch and would + # otherwise pollute kernel-only timing. + V.topk_output_sglang_fused( + inp.scores, inp.dense_kv_indptr, inp.sparse_kv_indptr, inp.dense_kv_indices, + inp.out, inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages, + mapping_mode, mapping_power, None, None, + ) + + +def call_topk_cu(inp: Inputs): + # NOTE: arg order differs from sglang variants - + # (x, dense_kv_indptr, dense_kv_indices, sparse_kv_indptr, ...) + V.topk_output( + inp.scores, inp.dense_kv_indptr, inp.dense_kv_indices, inp.sparse_kv_indptr, + inp.out, inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages, + ) + + +def call_adaptive(inp: Inputs, ws: dict, mapping_mode: int, mapping_power: float, + forced_split: int, forced_partition: int, local_mode: int): + V.topk_output_adaptive_workspace( + inp.scores, inp.dense_kv_indptr, inp.sparse_kv_indptr, inp.dense_kv_indices, + inp.out, ws["partial_keys"], ws["partial_indices"], ws["done_counter"], + inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages, + mapping_mode, mapping_power, + forced_split, forced_partition, local_mode, + ) + + +# --------------------------------------------------------------------------- # +# actual_path classification — mirrors the C++ dispatcher contract. +# --------------------------------------------------------------------------- # +def classify_adaptive_actual_path(K: int, split: int, pages: int, + local_mode: int) -> Tuple[str, Optional[int], bool]: + """Returns (actual_path, actual_split, is_supported). + + actual_split is the effective split count (None for fused fallback). + is_supported is False when the call would TORCH_CHECK fail. + """ + if K >= K_FUSED_FALLBACK: + return ("fused_fallback_large_k", None, True) + if K > K_MAX_ADAPTIVE: + return ("fused_fallback_mid_k", None, True) + if local_mode == LOCAL_BLOCK_FULL_SORT: + chunk_max = (pages + split - 1) // split + cap = BLOCK_FULL_SORT_CAPACITY.get(split, 0) + if cap < chunk_max: + return ("unsupported_capacity", split, False) + return ("adaptive_block_full_sort", split, True) + return ("adaptive_select32_sort32", split, True) + + +# --------------------------------------------------------------------------- # +# Sweep driver. +# --------------------------------------------------------------------------- # +@dataclass +class Row: + device_name: str + sm_count: int + dtype: str + mapping_mode: int + mapping_name: str + mapping_power: float + pages: int + topk: int + batch: int + method: str # "topk_sglang_fused", "topk_cu", "adaptive_merge" + requested_split: Optional[int] + actual_split: Optional[int] + local_mode: str # "BLOCK_FULL_SORT" / "SELECT32_SORT32" / "n/a" + merge_mode: str # "warp_cub" (production); "n/a" if no merge + candidate_count: Optional[int] # split * local_k; None if no merge + actual_path: str + mean_us: Optional[float] + p50_us: Optional[float] + p90_us: Optional[float] + min_us: Optional[float] + max_us: Optional[float] + std_us: Optional[float] + correctness: Optional[bool] + speedup_vs_sglang_fused: Optional[float] + speedup_vs_topk_cu: Optional[float] + notes: str + + +def run_one_setting(pages: int, K: int, B: int, mapping_mode: int, mapping_power: float, + dtype_str: str, splits: List[int], local_mode: int, + warmup: int, repeat: int, ws: dict, + device_name: str, sm_count: int) -> List[Row]: + """Bench every method at one (pages, K, B, mapping, dtype) cell.""" + rows: List[Row] = [] + inp = make_inputs(B, pages, K, _dtype_str_to_torch(dtype_str)) + ref_sets, ref_remapped, thresholds = reference_topk(inp, mapping_mode, mapping_power) + map_name = MAPPING_NAMES.get(mapping_mode, str(mapping_mode)) + + def _row(**kw): + defaults = dict( + device_name=device_name, sm_count=sm_count, dtype=dtype_str, + mapping_mode=mapping_mode, mapping_name=map_name, mapping_power=mapping_power, + pages=pages, topk=K, batch=B, + requested_split=None, actual_split=None, + local_mode="n/a", merge_mode="n/a", candidate_count=None, + mean_us=None, p50_us=None, p90_us=None, min_us=None, max_us=None, std_us=None, + correctness=None, speedup_vs_sglang_fused=None, speedup_vs_topk_cu=None, + notes="", + ) + defaults.update(kw) + return Row(**defaults) + + # ---------- topk_sglang.cu fused baseline ---------- + fused_us: Optional[float] = None + try: + inp.out.fill_(-1) + call_fused(inp, mapping_mode, mapping_power) + torch.cuda.synchronize() + ok, note = check_correctness(inp, ref_sets, ref_remapped, thresholds, + mapping_mode, mapping_power) + except RuntimeError as e: + # Most common: pages > fused dynamic-smem ceiling (~96KB → ~96k pages). + msg = str(e) + path = "fused_unavailable_smem" if "exceeds" in msg or "smem" in msg.lower() else "error" + rows.append(_row(method="topk_sglang_fused", actual_path=path, + correctness=False, notes=f"raised: {msg[:160]}")) + except Exception as e: + rows.append(_row(method="topk_sglang_fused", actual_path="error", + correctness=False, notes=f"raised: {e}")) + else: + t = time_kernel_us(lambda: call_fused(inp, mapping_mode, mapping_power), + warmup, repeat) + if t is None: + rows.append(_row(method="topk_sglang_fused", actual_path="error", + correctness=ok, notes="time_kernel_us returned None")) + else: + fused_us = t["mean"] + rows.append(_row( + method="topk_sglang_fused", actual_path="fused", + mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"], + min_us=t["min"], max_us=t["max"], std_us=t["std"], + correctness=ok, + speedup_vs_sglang_fused=1.0, + notes=note, + )) + + # ---------- topk.cu baseline (CUB full sort) ---------- + cub_us: Optional[float] = None + if pages > TOPK_CU_MAX_PAGES: + rows.append(_row(method="topk_cu", actual_path="topk_cu_unsupported", + notes=f"pages={pages} > template ladder cap {TOPK_CU_MAX_PAGES}")) + else: + try: + inp.out.fill_(-1) + call_topk_cu(inp) + torch.cuda.synchronize() + # topk.cu doesn't apply remap, so its output is for raw scores — + # ALWAYS check against the unmapped reference for fairness. + if mapping_mode in (MAPPING_NONE, MAPPING_TRUNC8): + ok, note = check_correctness(inp, ref_sets, ref_remapped, thresholds, + mapping_mode, mapping_power) + else: + ok, note = True, "remap unsupported by topk.cu; correctness skipped" + except Exception as e: + rows.append(_row(method="topk_cu", actual_path="error", + correctness=False, notes=f"raised: {e}")) + else: + t = time_kernel_us(lambda: call_topk_cu(inp), warmup, repeat) + if t is None: + rows.append(_row(method="topk_cu", actual_path="error", + correctness=ok, notes="time_kernel_us returned None")) + else: + cub_us = t["mean"] + rows.append(_row( + method="topk_cu", actual_path="cub_full_sort", + mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"], + min_us=t["min"], max_us=t["max"], std_us=t["std"], + correctness=ok, + speedup_vs_sglang_fused=(fused_us / t["mean"]) if fused_us else None, + speedup_vs_topk_cu=1.0, + notes=note, + )) + + # ---------- topk_sglang_merge.cu adaptive (one row per requested split) ---------- + local_mode_str = LOCAL_MODE_NAMES.get(local_mode, "unknown") + for split in splits: + actual_path, actual_split, is_supported = classify_adaptive_actual_path( + K, split, pages, local_mode) + # Adaptive paths use cub::WarpMergeSort over (split * 32) candidates; + # split=1 has no merge stage at all. + on_adaptive_path = actual_path.startswith("adaptive_") + merge_mode = PROD_MERGE_NAME if (on_adaptive_path and split > 1) else "n/a" + candidate_count = (split * 32) if (on_adaptive_path and split > 1) else None + local_mode_for_row = local_mode_str if on_adaptive_path else "n/a" + + if not is_supported: + rows.append(_row( + method="adaptive_merge", requested_split=split, actual_split=actual_split, + local_mode=local_mode_for_row, merge_mode=merge_mode, + candidate_count=candidate_count, actual_path=actual_path, + notes=(f"BLOCK_FULL_SORT cap={BLOCK_FULL_SORT_CAPACITY.get(split,0)} " + f"< chunk_max={(pages + split - 1)//split}"), + )) + continue + + try: + inp.out.fill_(-1) + call_adaptive(inp, ws, mapping_mode, mapping_power, + forced_split=split, forced_partition=1, # CONTIGUOUS + local_mode=local_mode) + torch.cuda.synchronize() + ok, note = check_correctness(inp, ref_sets, ref_remapped, thresholds, + mapping_mode, mapping_power) + except RuntimeError as e: + # K>32 and pages too large for fused fallback's smem. + msg = str(e) + err_path = ("fused_fallback_unavailable_smem" + if (not on_adaptive_path and ("exceeds" in msg or "smem" in msg.lower())) + else actual_path + "_error") + rows.append(_row(method="adaptive_merge", + requested_split=split, actual_split=actual_split, + local_mode=local_mode_for_row, merge_mode=merge_mode, + candidate_count=candidate_count, + actual_path=err_path, + correctness=False, notes=f"raised: {msg[:160]}")) + continue + except Exception as e: + rows.append(_row(method="adaptive_merge", + requested_split=split, actual_split=actual_split, + local_mode=local_mode_for_row, merge_mode=merge_mode, + candidate_count=candidate_count, actual_path=actual_path, + correctness=False, notes=f"raised: {e}")) + continue + + # For fused-fallback paths, all forced_split values produce identical + # timings (same fused kernel called); we still time each entry to + # quantify dispatcher overhead. + t = time_kernel_us( + lambda: call_adaptive(inp, ws, mapping_mode, mapping_power, + forced_split=split, forced_partition=1, + local_mode=local_mode), + warmup, repeat, + ) + if t is None: + rows.append(_row(method="adaptive_merge", + requested_split=split, actual_split=actual_split, + local_mode=local_mode_for_row, merge_mode=merge_mode, + candidate_count=candidate_count, actual_path=actual_path, + correctness=ok, notes="time_kernel_us returned None")) + continue + + rows.append(_row( + method="adaptive_merge", + requested_split=split, actual_split=actual_split, + local_mode=local_mode_for_row, merge_mode=merge_mode, + candidate_count=candidate_count, actual_path=actual_path, + mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"], + min_us=t["min"], max_us=t["max"], std_us=t["std"], + correctness=ok, + speedup_vs_sglang_fused=(fused_us / t["mean"]) if fused_us else None, + speedup_vs_topk_cu=(cub_us / t["mean"]) if cub_us else None, + notes=note, + )) + + return rows + + +# --------------------------------------------------------------------------- # +# Merge-mode ablation (K=30 only — the ablation kernels in +# topk_adaptive_profile.cu are hardcoded to kLocalK_Top30 = 32). +# --------------------------------------------------------------------------- # +def call_ablation(inp: Inputs, ws: dict, scratch: torch.Tensor, + ablation_mode: int, forced_split: int): + V.topk_output_adaptive_workspace_ablation( + inp.scores, inp.dense_kv_indptr, inp.sparse_kv_indptr, inp.dense_kv_indices, + inp.out, ws["partial_keys"], ws["partial_indices"], ws["done_counter"], + scratch, + inp.B, inp.K, inp.reserved_bos, inp.reserved_eos, inp.pages, + ablation_mode, forced_split, + ) + + +def run_merge_ablation(pages_list, batches, splits, warmup, repeat, + ws, device_name, sm_count) -> List[dict]: + """Per (pages, B, split, merge_mode), measure merge-only latency. + + Workflow per cell: + 1. Populate the workspace via ablation_mode = LocalWithWorkspace (mode 1). + 2. For each merge variant, time merge-only kernel (modes 5/6/7/11). + + Returns list of dicts (CSV-ready).""" + rows = [] + device = torch.device("cuda") + scratch = torch.zeros(max(1, max(batches) * max(splits)), + device=device, dtype=torch.int32) + K = 30 # ablation harness is K<=32 only + for pages in pages_list: + for B in batches: + inp = make_inputs(B, pages, K, torch.bfloat16) + for split in splits: + if split <= 1: + continue # nothing to merge + # Step 1: populate the workspace (ablation_mode=1). + try: + call_ablation(inp, ws, scratch, + ABL_LOCAL_WITH_WORKSPACE, split) + torch.cuda.synchronize() + except Exception as e: + rows.append(dict(pages=pages, batch=B, split=split, + merge_mode="setup_failed", + mean_us=None, notes=f"populate raised: {e}")) + continue + # Step 2: merge variants. Some require specific splits. + variants = [ABL_MERGE_PROD_DEFAULT, ABL_MERGE_CUB_WARP, + ABL_MERGE_CUB_BLOCK, ABL_MERGE_KWAY] + for ablv in variants: + name = MERGE_ABL_NAMES[ablv] + try: + call_ablation(inp, ws, scratch, ablv, split) + torch.cuda.synchronize() + except Exception as e: + rows.append(dict(pages=pages, batch=B, split=split, + merge_mode=name, candidate_count=split * 32, + mean_us=None, + notes=f"raised: {repr(e)[:120]}")) + continue + t = time_kernel_us( + lambda av=ablv, sp=split: call_ablation(inp, ws, scratch, av, sp), + warmup, repeat, + ) + if t is None: + rows.append(dict(pages=pages, batch=B, split=split, + merge_mode=name, candidate_count=split * 32, + mean_us=None, notes="time_kernel_us failed")) + continue + rows.append(dict( + device_name=device_name, sm_count=sm_count, + pages=pages, batch=B, split=split, + merge_mode=name, candidate_count=split * 32, + mean_us=t["mean"], p50_us=t["p50"], p90_us=t["p90"], + min_us=t["min"], max_us=t["max"], std_us=t["std"], + notes="", + )) + return rows + + +def write_merge_mode_csv(merge_rows: List[dict], path: Path): + if not merge_rows: + with path.open("w") as f: + f.write("# merge ablation skipped (use --merge-ablation to enable)\n") + return + # Pivot to wide form: one row per (pages, batch, split) with columns per merge mode. + by_key = {} + for r in merge_rows: + key = (r["pages"], r["batch"], r["split"]) + by_key.setdefault(key, {"candidate_count": r.get("candidate_count")}) + by_key[key][r["merge_mode"]] = r.get("mean_us") + cols = ["pages", "batch", "split", "candidate_count", + "warp_cub_us", "block_cub_us", "kway_us", "prod_default_us", + "best_merge_mode", "best_merge_us", "speedup_best_vs_warp"] + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for key in sorted(by_key): + pages, B, split = key + d = by_key[key] + warp = d.get("warp_cub") + block = d.get("block_cub") + kway = d.get("kway") + prod = d.get("prod_default(legacy)") + choices = [(name, t) for name, t in + (("warp_cub", warp), ("block_cub", block), + ("kway", kway), ("prod_default", prod)) + if t is not None] + if choices: + best_name, best_us = min(choices, key=lambda x: x[1]) + sp = (warp / best_us) if (warp and best_us) else None + else: + best_name, best_us, sp = "n/a", None, None + w.writerow([pages, B, split, d.get("candidate_count"), + f"{warp:.3f}" if warp else "", + f"{block:.3f}" if block else "", + f"{kway:.3f}" if kway else "", + f"{prod:.3f}" if prod else "", + best_name, + f"{best_us:.3f}" if best_us else "", + f"{sp:.3f}" if sp else ""]) + print(f"wrote {path}") + + +# --------------------------------------------------------------------------- # +# Adversarial correctness — additional unit-test-style cases. +# --------------------------------------------------------------------------- # +def adversarial_correctness_test(local_mode: int) -> List[dict]: + """Return list of dicts describing each adversarial case + per-method outcome.""" + device = torch.device("cuda") + K, RES_BOS, RES_EOS, B, PAGES = 30, 1, 2, 2, 4096 + cases = [] + + def build_scores(kind: str, dtype) -> torch.Tensor: + n = B * PAGES + if kind == "all_equal": + return torch.full((n,), 1.5, device=device, dtype=dtype) + if kind == "tie_heavy_high8": + x = torch.randn(n, device=device, dtype=torch.float32) + mask = torch.rand(n, device=device) < 0.05 + x[mask] = 100.0 # identical large values - ties for top-K + return x.to(dtype) + if kind == "mixed_sign": + x = torch.randn(n, device=device, dtype=torch.float32) * 10 + return x.to(dtype) + if kind == "threshold_overflow": + x = torch.zeros(n, device=device, dtype=torch.float32) + mask = torch.rand(n, device=device) < 0.10 + x[mask] = 1.0 # > K items at one bin + return x.to(dtype) + raise ValueError(kind) + + ws = make_workspace(B, max_split=32, K_local=32) + for kind in ("all_equal", "tie_heavy_high8", "mixed_sign", "threshold_overflow"): + scores = build_scores(kind, torch.bfloat16) + inp = make_inputs(B, PAGES, K, torch.bfloat16) + inp.scores.copy_(scores) + ref_sets, ref_remapped, thresholds = reference_topk(inp, MAPPING_NONE, 0.5) + # Fused + try: + inp.out.fill_(-1) + call_fused(inp, MAPPING_NONE, 0.5); torch.cuda.synchronize() + ok_f, note_f = check_correctness(inp, ref_sets, ref_remapped, thresholds, MAPPING_NONE, 0.5) + except Exception as e: + ok_f, note_f = False, f"raised: {e}" + # Adaptive split=1 (production path). + try: + inp.out.fill_(-1) + call_adaptive(inp, ws, MAPPING_NONE, 0.5, + forced_split=1, forced_partition=1, local_mode=local_mode) + torch.cuda.synchronize() + ok_a1, note_a1 = check_correctness(inp, ref_sets, ref_remapped, thresholds, MAPPING_NONE, 0.5) + except Exception as e: + ok_a1, note_a1 = False, f"raised: {e}" + # Adaptive split=4 (merge path). + try: + inp.out.fill_(-1) + call_adaptive(inp, ws, MAPPING_NONE, 0.5, + forced_split=4, forced_partition=1, local_mode=local_mode) + torch.cuda.synchronize() + ok_a4, note_a4 = check_correctness(inp, ref_sets, ref_remapped, thresholds, MAPPING_NONE, 0.5) + except Exception as e: + ok_a4, note_a4 = False, f"raised: {e}" + cases.append(dict(case=kind, fused_ok=ok_f, fused_note=note_f, + adapt_split1_ok=ok_a1, adapt_split1_note=note_a1, + adapt_split4_ok=ok_a4, adapt_split4_note=note_a4)) + return cases + + +# --------------------------------------------------------------------------- # +# CSV / report writers. +# --------------------------------------------------------------------------- # +RAW_COLUMNS = [ + "device_name", "sm_count", "dtype", "mapping_mode", "mapping_name", "mapping_power", + "pages", "topk", "batch", "method", + "requested_split", "actual_split", + "local_mode", "merge_mode", "candidate_count", + "actual_path", + "mean_us", "p50_us", "p90_us", "min_us", "max_us", "std_us", + "correctness", "speedup_vs_sglang_fused", "speedup_vs_topk_cu", "notes", +] + + +def write_raw_csv(rows: List[Row], path: Path): + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(RAW_COLUMNS) + for r in rows: + d = asdict(r) + w.writerow([d[c] for c in RAW_COLUMNS]) + print(f"wrote {path} ({len(rows)} rows)") + + +def write_best_adaptive_csv(rows: List[Row], path: Path): + """One row per (pages,K,B,mapping,dtype). Choose best adaptive split among + rows whose actual_path starts with 'adaptive_'. Also emit fused/cub for ref.""" + cols = [ + "pages", "topk", "batch", "mapping_name", "dtype", + "best_adaptive_split", "best_adaptive_local_mode", "best_adaptive_merge_mode", + "best_adaptive_latency_us", "best_adaptive_actual_path", + "sglang_fused_latency_us", "topk_cu_latency_us", + "speedup_best_adaptive_vs_sglang", "speedup_best_adaptive_vs_topk_cu", + "adaptive_wins_vs_sglang", "adaptive_wins_vs_topk_cu", + ] + by_key: Dict[Tuple, Dict[str, object]] = {} + for r in rows: + key = (r.pages, r.topk, r.batch, r.mapping_name, r.dtype) + rec = by_key.setdefault(key, dict(adaptive=[], fused_us=None, cub_us=None)) + if r.method == "topk_sglang_fused" and r.correctness and r.mean_us is not None: + rec["fused_us"] = r.mean_us + elif r.method == "topk_cu" and r.correctness and r.mean_us is not None: + rec["cub_us"] = r.mean_us + elif (r.method == "adaptive_merge" and r.correctness + and r.actual_path.startswith("adaptive_") + and r.mean_us is not None): + rec["adaptive"].append((r.mean_us, r.requested_split, r.local_mode, + r.merge_mode, r.actual_path)) + + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for key, rec in sorted(by_key.items()): + pages, K, B, mapping_name, dtype = key + best = min(rec["adaptive"], default=None) + if best is None: + best_us, best_split, best_local, best_merge, best_path = None, None, "n/a", "n/a", "n/a" + else: + best_us, best_split, best_local, best_merge, best_path = best + fused_us = rec["fused_us"] + cub_us = rec["cub_us"] + sp_f = (fused_us / best_us) if (best_us and fused_us) else None + sp_c = (cub_us / best_us) if (best_us and cub_us) else None + wins_f = (sp_f is not None and sp_f >= WIN_THRESHOLD) + wins_c = (sp_c is not None and sp_c >= WIN_THRESHOLD) + w.writerow([ + pages, K, B, mapping_name, dtype, + best_split, best_local, best_merge, + f"{best_us:.3f}" if best_us is not None else "", + best_path, + f"{fused_us:.3f}" if fused_us is not None else "", + f"{cub_us:.3f}" if cub_us is not None else "", + f"{sp_f:.3f}" if sp_f is not None else "", + f"{sp_c:.3f}" if sp_c is not None else "", + wins_f, wins_c, + ]) + print(f"wrote {path}") + + +def k_bucket(K: int) -> str: + if K <= 32: return "small_K(<=32)" + if K <= 512: return "mid_K(64-512)" + return "large_K(>=1024)" + + +def write_advantage_summary_csv(rows: List[Row], path: Path): + """Group by (k_bucket, pages, batch). Count adaptive wins, mean/best speedup, + best split distribution, common actual_path.""" + by_key: Dict[Tuple, Dict[str, list]] = {} + # Build per-cell best_adaptive entries (one per setting). + setting_best: Dict[Tuple, Dict] = {} + for r in rows: + if not (r.method == "adaptive_merge" and r.correctness and r.mean_us is not None): + continue + if not r.actual_path.startswith("adaptive_"): + continue + key = (r.pages, r.topk, r.batch, r.mapping_name, r.dtype) + rec = setting_best.setdefault(key, {"best_us": float("inf"), "split": None, "path": None}) + if r.mean_us < rec["best_us"]: + rec["best_us"] = r.mean_us; rec["split"] = r.requested_split; rec["path"] = r.actual_path + fused_lookup = {(r.pages, r.topk, r.batch, r.mapping_name, r.dtype): r.mean_us + for r in rows + if r.method == "topk_sglang_fused" and r.correctness and r.mean_us is not None} + + for setting, best in setting_best.items(): + pages, K, B, mapping_name, dtype = setting + bucket = k_bucket(K) + gkey = (bucket, pages, B, mapping_name, dtype) + agg = by_key.setdefault(gkey, dict(speedups=[], splits=[], paths=[], total=0, wins=0)) + agg["total"] += 1 + fused_us = fused_lookup.get(setting) + if fused_us: + sp = fused_us / best["best_us"] + agg["speedups"].append(sp) + if sp >= WIN_THRESHOLD: agg["wins"] += 1 + agg["splits"].append(best["split"]) + agg["paths"].append(best["path"]) + + cols = ["k_bucket", "pages", "batch", "mapping_name", "dtype", + "n_settings", "n_adaptive_wins", "win_rate", + "best_speedup", "mean_speedup", "median_speedup", + "best_split_mode", "best_split_distribution", "common_actual_path"] + with path.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for gkey, agg in sorted(by_key.items()): + bucket, pages, B, mapping_name, dtype = gkey + sps = agg["speedups"] + splits = [s for s in agg["splits"] if s is not None] + mode_split = (statistics.mode(splits) if splits else None) + split_dist = "|".join(f"{s}:{splits.count(s)}" for s in sorted(set(splits))) if splits else "" + paths = [p for p in agg["paths"] if p] + common_path = statistics.mode(paths) if paths else "" + w.writerow([ + bucket, pages, B, mapping_name, dtype, + agg["total"], agg["wins"], + f"{agg['wins']/agg['total']:.2f}" if agg["total"] else "", + f"{max(sps):.3f}" if sps else "", + f"{statistics.mean(sps):.3f}" if sps else "", + f"{statistics.median(sps):.3f}" if sps else "", + mode_split, split_dist, common_path, + ]) + print(f"wrote {path}") + + +def _fmt(v, nd=2, default=" -"): + if v is None: return default + return f"{v:>{nd+5}.{nd}f}" + + +def print_per_K_tables(rows: List[Row], splits: List[int], mapping_filter: Optional[str] = None): + """Compact per-K terminal tables.""" + keys = sorted({(r.topk, r.mapping_name, r.dtype, r.pages, r.batch) for r in rows}) + by_k_map = {} + for r in rows: + if mapping_filter is not None and r.mapping_name != mapping_filter: + continue + by_k_map.setdefault((r.topk, r.mapping_name, r.dtype), []).append(r) + for (K, mapping_name, dtype), kr in sorted(by_k_map.items()): + print() + print(f"=== K={K} mapping={mapping_name} dtype={dtype} ===") + # Column header + hdr = (f"{'pages':>6} {'B':>3} {'fused_us':>9} {'cub_us':>8} " + + " ".join(f"{'s='+str(s):>9}" for s in splits) + + f" {'best_us':>8} {'split':>5} {'sp_vs_fused':>11}") + print(hdr) + print("-" * len(hdr)) + cells = {} + for r in kr: + cells.setdefault((r.pages, r.batch), {}) + kk = (r.pages, r.batch) + if r.method == "topk_sglang_fused": + cells[kk]["fused"] = r.mean_us + elif r.method == "topk_cu": + cells[kk]["cub"] = r.mean_us + elif r.method == "adaptive_merge": + cells[kk].setdefault("adapt", {})[r.requested_split] = (r.mean_us, r.actual_path) + for (pages, B), c in sorted(cells.items()): + adapt = c.get("adapt", {}) + adapt_us = {s: (adapt.get(s, (None, ""))[0]) for s in splits} + valid = [(s, adapt[s][0]) for s in splits if s in adapt + and adapt[s][0] is not None + and adapt[s][1].startswith("adaptive_")] + if valid: + best_split, best_us = min(valid, key=lambda kv: kv[1]) + else: + best_split, best_us = None, None + fused_us = c.get("fused") + sp = (fused_us / best_us) if (fused_us and best_us) else None + print( + f"{pages:>6d} {B:>3d} {_fmt(fused_us):>9} {_fmt(c.get('cub')):>8} " + + " ".join(f"{_fmt(adapt_us[s]):>9}" for s in splits) + + f" {_fmt(best_us):>8} {str(best_split) if best_split else '-':>5} " + + (f"{sp:>10.3f}x" if sp else f"{'-':>11}") + ) + + +def write_markdown_report(rows: List[Row], device_info: dict, args, path: Path, + splits: List[int], best_csv: Path, advantage_csv: Path, + raw_csv: Path, merge_csv: Optional[Path] = None, + merge_rows: Optional[List[dict]] = None, + adversarial_csv: Optional[Path] = None, + adversarial_rows: Optional[List[dict]] = None): + failed = [r for r in rows if r.correctness is False] + n_total = len(rows) + with path.open("w") as f: + f.write(f"# TopK Setting Sweep Report\n\n") + f.write(f"- Device: **{device_info['name']}** (SMs: {device_info['sm_count']})\n") + f.write(f"- torch: {torch.__version__} CUDA: {torch.version.cuda}\n") + f.write(f"- Pages: {args.pages}\n") + f.write(f"- K: {args.ks}\n") + f.write(f"- Batches: {args.batches}\n") + f.write(f"- Adaptive splits: {splits}\n") + f.write(f"- Mappings: {args.mappings}\n") + f.write(f"- Local mode: {'BLOCK_FULL_SORT' if args.local_mode == LOCAL_BLOCK_FULL_SORT else 'SELECT32_SORT32'}\n") + f.write(f"- warmup={args.warmup}, repeat={args.repeat}\n") + f.write(f"- Total measurements: {n_total}, correctness failures: {len(failed)}\n\n") + + # Per-K compact tables. + f.write("## Per-K latency tables (us)\n\n") + by_k = {} + for r in rows: + by_k.setdefault((r.topk, r.mapping_name, r.dtype), []).append(r) + for (K, mapping_name, dtype), kr in sorted(by_k.items()): + f.write(f"### K={K}, mapping={mapping_name}, dtype={dtype}\n\n") + cells = {} + for r in kr: + kk = (r.pages, r.batch) + cells.setdefault(kk, {}) + if r.method == "topk_sglang_fused": cells[kk]["fused"] = r.mean_us + elif r.method == "topk_cu": cells[kk]["cub"] = r.mean_us + elif r.method == "adaptive_merge": + cells[kk].setdefault("adapt", {})[r.requested_split] = (r.mean_us, r.actual_path) + head = (["pages", "B", "fused_us", "cub_us"] + + [f"adapt_s{s}_us" for s in splits] + + ["best_us", "best_split", "actual_path", "speedup_vs_fused"]) + f.write("| " + " | ".join(head) + " |\n") + f.write("|" + "|".join("---:" for _ in head) + "|\n") + for (pages, B), c in sorted(cells.items()): + adapt = c.get("adapt", {}) + row = [str(pages), str(B), + f"{c.get('fused'):.2f}" if c.get('fused') else "-", + f"{c.get('cub'):.2f}" if c.get('cub') else "-"] + for s in splits: + val = adapt.get(s, (None, ""))[0] + row.append(f"{val:.2f}" if val else "-") + valid = [(s, adapt[s][0], adapt[s][1]) for s in splits + if s in adapt and adapt[s][0] is not None + and adapt[s][1].startswith("adaptive_")] + if valid: + best_split, best_us, best_path = min(valid, key=lambda x: x[1]) + else: + best_split, best_us, best_path = None, None, "-" + fused_us = c.get('fused') + sp = (fused_us / best_us) if (fused_us and best_us) else None + # If everything was fused-fallback, fall back to noting that. + if best_us is None: + fb = next((v for v in adapt.values() if v[0] is not None), (None, "-")) + row += ["-", "-", fb[1], "-"] + else: + row += [f"{best_us:.2f}", str(best_split), best_path, + f"{sp:.3f}x" if sp else "-"] + f.write("| " + " | ".join(row) + " |\n") + f.write("\n") + + # Merge-mode ablation (K=30 only). + if merge_csv is not None and merge_rows: + f.write("## Merge-mode ablation (K=30, merge stage in isolation)\n\n") + f.write("Source: `topk_output_adaptive_workspace_ablation` modes 5/6/7/11.\n\n") + with merge_csv.open() as g: + f.write("```\n" + g.read() + "```\n\n") + + # Region analysis. + f.write("## Parallel-advantage region analysis\n\n") + f.write(f"Win threshold: speedup_vs_sglang >= {WIN_THRESHOLD}.\n\n") + with advantage_csv.open() as g: + f.write("```\n" + g.read() + "```\n\n") + + # Adversarial correctness. + if adversarial_rows is not None: + f.write("## Adversarial correctness cases\n\n") + f.write("| case | fused | adapt s=1 | adapt s=4 |\n") + f.write("|---|:-:|:-:|:-:|\n") + for r in adversarial_rows: + f.write(f"| {r['case']} | {'PASS' if r['fused_ok'] else 'FAIL'} | " + f"{'PASS' if r['adapt_split1_ok'] else 'FAIL'} | " + f"{'PASS' if r['adapt_split4_ok'] else 'FAIL'} |\n") + f.write("\n") + + # Recommended dispatch policy + f.write("## Recommended production dispatch policy\n\n") + # Compute best split per (K bucket, pages) by majority best_split. + best_by_bucket = {} + for r in rows: + if not (r.method == "adaptive_merge" and r.correctness + and r.mean_us is not None + and r.actual_path.startswith("adaptive_")): + continue + key = (k_bucket(r.topk), r.pages, r.batch, r.mapping_name) + ent = best_by_bucket.setdefault(key, {"best": (float('inf'), None)}) + if r.mean_us < ent["best"][0]: + ent["best"] = (r.mean_us, r.requested_split) + bucket_splits = {} + for (bucket, pages, B, mapping), v in best_by_bucket.items(): + bucket_splits.setdefault((bucket, pages, B, mapping), []).append(v["best"][1]) + f.write("| K_bucket | pages | B | mapping | recommended_split |\n") + f.write("|---|---:|---:|---|---:|\n") + for key, splits_list in sorted(bucket_splits.items()): + bucket, pages, B, mapping = key + try: + rec = statistics.mode(splits_list) + except statistics.StatisticsError: + rec = sorted(splits_list)[0] + f.write(f"| {bucket} | {pages} | {B} | {mapping} | {rec} |\n") + f.write("\n- For `large_K(>=1024)` adaptive entry routes to fused (zero-overhead).\n") + f.write("- For `mid_K(64-512)` adaptive entry currently routes to fused; ") + f.write("a dedicated mid-K kernel is future work.\n") + + # Failures. + if failed: + f.write("\n## Correctness failures\n\n") + f.write("| pages | K | B | mapping | method | split | actual_path | notes |\n") + f.write("|---:|---:|---:|---|---|---:|---|---|\n") + for r in failed: + f.write(f"| {r.pages} | {r.topk} | {r.batch} | {r.mapping_name} | " + f"{r.method} | {r.requested_split} | {r.actual_path} | " + f"{r.notes} |\n") + f.write(f"\n## Files\n- raw csv: `{raw_csv}`\n") + f.write(f"- best adaptive csv: `{best_csv}`\n") + f.write(f"- advantage summary csv: `{advantage_csv}`\n") + print(f"wrote {path}") + + +# --------------------------------------------------------------------------- # +# CLI. +# --------------------------------------------------------------------------- # +def parse_args(): + p = argparse.ArgumentParser(description=__doc__) + p.add_argument("--pages", type=int, nargs="+", default=DEFAULT_PAGES) + p.add_argument("--ks", type=int, nargs="+", default=DEFAULT_KS) + p.add_argument("--batches", type=int, nargs="+", default=DEFAULT_BATCHES) + p.add_argument("--splits", type=int, nargs="+", default=DEFAULT_SPLITS) + p.add_argument("--mappings", type=str, nargs="+", default=DEFAULT_MAPPING_NAMES, + help="Mapping mode names (NONE, TRUNC8, POWER, LOG, ASINH, LOG1P, ERF, TANH).") + p.add_argument("--mapping-power", type=float, default=0.5) + p.add_argument("--dtypes", type=str, nargs="+", default=DEFAULT_DTYPES) + p.add_argument("--local-mode", type=int, default=LOCAL_SELECT32_SORT32, + choices=[LOCAL_BLOCK_FULL_SORT, LOCAL_SELECT32_SORT32], + help="0=BLOCK_FULL_SORT, 1=SELECT32_SORT32 (default).") + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--repeat", type=int, default=200) + p.add_argument("--output-dir", type=Path, + default=Path("bench_results") / time.strftime("setting_sweep_%Y%m%d_%H%M%S")) + p.add_argument("--print-tables", action="store_true", + help="Also print per-K latency tables to stdout (large output).") + p.add_argument("--merge-ablation", action="store_true", default=True, + help="Run merge-mode ablation sub-sweep (K=30 only). Default: on.") + p.add_argument("--no-merge-ablation", dest="merge_ablation", action="store_false", + help="Skip the merge-mode ablation sub-sweep.") + p.add_argument("--adversarial", action="store_true", default=True, + help="Run adversarial correctness cases (default: on).") + p.add_argument("--no-adversarial", dest="adversarial", action="store_false", + help="Skip adversarial correctness cases.") + return p.parse_args() + + +def main(): + args = parse_args() + if not torch.cuda.is_available(): + sys.exit("CUDA is required.") + args.output_dir.mkdir(parents=True, exist_ok=True) + + device_info = dict( + name=torch.cuda.get_device_name(0), + sm_count=torch.cuda.get_device_properties(0).multi_processor_count, + ) + print(f"Device: {device_info['name']} SMs={device_info['sm_count']}") + print(f"Output dir: {args.output_dir.resolve()}") + + # Validate mappings. + mapping_modes = [] + for name in args.mappings: + if name not in MAPPING_BY_NAME: + sys.exit(f"unknown mapping name: {name} (valid: {list(MAPPING_BY_NAME)})") + mapping_modes.append(MAPPING_BY_NAME[name]) + + # Pre-allocate workspace large enough for the largest configuration. + B_max = max(args.batches) + ws = make_workspace(B_max=B_max, max_split=max(args.splits), K_local=32) + + configs = [(pages, K, B, mode, dtype) + for pages in args.pages + for K in args.ks + for B in args.batches + for mode in mapping_modes + for dtype in args.dtypes] + print(f"Configs: {len(configs)} (each runs fused + cub + {len(args.splits)} adaptive)") + print(f"Splits: {args.splits} warmup={args.warmup} repeat={args.repeat}") + + rows: List[Row] = [] + t0 = time.time() + for i, cfg in enumerate(configs, 1): + pages, K, B, mode, dtype = cfg + if i % 5 == 0 or i == 1: + print(f"[{i:3d}/{len(configs)}] pages={pages} K={K} B={B} " + f"mapping={MAPPING_NAMES[mode]} dtype={dtype} " + f"(elapsed {time.time()-t0:.1f}s)") + rows.extend(run_one_setting( + pages, K, B, mode, args.mapping_power, dtype, + args.splits, args.local_mode, + args.warmup, args.repeat, ws, + device_info["name"], device_info["sm_count"], + )) + print(f"\nSweep complete in {time.time()-t0:.1f}s. rows={len(rows)}") + + # Output files. + raw_csv = args.output_dir / "topk_setting_sweep_raw.csv" + best_csv = args.output_dir / "topk_setting_sweep_best_adaptive.csv" + advantage_csv = args.output_dir / "topk_parallel_advantage_summary.csv" + merge_csv = args.output_dir / "topk_merge_mode_summary.csv" + adversarial_csv = args.output_dir / "topk_adversarial_correctness.csv" + report_md = args.output_dir / "topk_setting_sweep_report.md" + + write_raw_csv(rows, raw_csv) + write_best_adaptive_csv(rows, best_csv) + write_advantage_summary_csv(rows, advantage_csv) + + # Merge-mode ablation (K=30 only). + merge_rows = [] + if args.merge_ablation: + print("\nMerge-mode ablation (K=30, ablation harness):") + merge_rows = run_merge_ablation( + pages_list=args.pages, batches=args.batches, + splits=args.splits, warmup=args.warmup, repeat=args.repeat, + ws=ws, device_name=device_info["name"], sm_count=device_info["sm_count"], + ) + print(f" collected {len(merge_rows)} merge-only timings") + write_merge_mode_csv(merge_rows, merge_csv) + + # Adversarial correctness check. + adv_rows = [] + if args.adversarial: + print("\nAdversarial correctness check (K=30, MAPPING_NONE, B=2, pages=4096):") + adv_rows = adversarial_correctness_test(args.local_mode) + for r in adv_rows: + print(f" case={r['case']:>22} fused={r['fused_ok']} " + f"adapt_s1={r['adapt_split1_ok']} adapt_s4={r['adapt_split4_ok']}") + with adversarial_csv.open("w", newline="") as f: + w = csv.writer(f) + w.writerow(["case", "fused_ok", "fused_note", + "adapt_split1_ok", "adapt_split1_note", + "adapt_split4_ok", "adapt_split4_note"]) + for r in adv_rows: + w.writerow([r["case"], r["fused_ok"], r["fused_note"], + r["adapt_split1_ok"], r["adapt_split1_note"], + r["adapt_split4_ok"], r["adapt_split4_note"]]) + print(f"wrote {adversarial_csv}") + + write_markdown_report(rows, device_info, args, report_md, + args.splits, best_csv, advantage_csv, raw_csv, + merge_csv=merge_csv, merge_rows=merge_rows, + adversarial_csv=adversarial_csv, adversarial_rows=adv_rows) + + # Optional terminal tables. + if args.print_tables: + print_per_K_tables(rows, args.splits) + + # Short summary. + print("\n" + "=" * 70) + print(f"Files written under: {args.output_dir.resolve()}") + print(f" raw : {raw_csv.name}") + print(f" best_adapt: {best_csv.name}") + print(f" advantage : {advantage_csv.name}") + print(f" report : {report_md.name}") + failed = [r for r in rows if r.correctness is False] + print(f"Correctness failures: {len(failed)}") + if failed: + for r in failed[:10]: + print(f" - pages={r.pages} K={r.topk} B={r.batch} " + f"map={r.mapping_name} method={r.method} split={r.requested_split} " + f"path={r.actual_path}: {r.notes}") + if len(failed) > 10: + print(f" ... and {len(failed) - 10} more (see raw csv)") + # Quick win-region rollup. + n_adaptive = sum(1 for r in rows if r.method == "adaptive_merge" + and r.actual_path.startswith("adaptive_") and r.correctness) + n_wins = sum(1 for r in rows if r.method == "adaptive_merge" + and r.actual_path.startswith("adaptive_") and r.correctness + and r.speedup_vs_sglang_fused is not None + and r.speedup_vs_sglang_fused >= WIN_THRESHOLD) + print(f"Adaptive-rows (correct, real adaptive path): {n_adaptive}") + print(f" -> wins vs fused (>= {WIN_THRESHOLD}x): {n_wins} " + f"({100.0 * n_wins / max(n_adaptive,1):.1f}%)") + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_adaptive_overhead.py b/benchmarks/profile_adaptive_overhead.py new file mode 100644 index 00000000..5c0b2118 --- /dev/null +++ b/benchmarks/profile_adaptive_overhead.py @@ -0,0 +1,212 @@ +"""Decompose the adaptive split-2 TopK kernel's latency into Phase-1, +Phase-2, and barrier+launch overhead — and compare against the naive +CUB sort (topk.cu) and the single-CTA radix baseline (topk_sglang.cu). + +No remap (mode=0), bfloat16 scores only, to keep the comparison clean. + +Usage: + python benchmarks/profile_adaptive_overhead.py [--gpu 4] +""" +from __future__ import annotations + +import argparse +import json +import math +from typing import Dict, List + +import torch + +from vortex_torch_C import ( + topk_output, + topk_output_sglang, + topk_output_adaptive, + topk_adaptive_phase1_only, + topk_adaptive_phase2_only, +) + + +def make_inputs(bs: int, pages: int, K: int, reserved_bos: int = 1, reserved_eos: int = 1, + device: str = "cuda") -> Dict[str, torch.Tensor]: + per_row = pages + reserved_bos + reserved_eos + dense_kv_indptr = torch.arange( + 0, (bs + 1) * per_row, per_row, device=device, dtype=torch.int32) + dense_kv_indices = torch.arange(bs * per_row, device=device, dtype=torch.int32) + per_sparse = K + reserved_bos + reserved_eos + sparse_kv_indptr = torch.arange( + 0, (bs + 1) * per_sparse, per_sparse, device=device, dtype=torch.int32) + sparse_kv_indices = torch.zeros(bs * per_sparse, device=device, dtype=torch.int32) + x = torch.randn(bs * per_row, device=device, dtype=torch.bfloat16) + partial_scores = torch.empty(bs * 2 * K, device=device, dtype=torch.float32) + partial_indices = torch.empty(bs * 2 * K, device=device, dtype=torch.int32) + return dict( + x=x, + dense_kv_indptr=dense_kv_indptr, + dense_kv_indices=dense_kv_indices, + sparse_kv_indptr=sparse_kv_indptr, + sparse_kv_indices=sparse_kv_indices, + partial_scores=partial_scores, + partial_indices=partial_indices, + ) + + +def time_kernel(fn, args, warmup: int = 20, repeat: int = 200) -> float: + """Return mean ms.""" + for _ in range(warmup): + fn(*args) + torch.cuda.synchronize() + starts = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + ends = [torch.cuda.Event(enable_timing=True) for _ in range(repeat)] + for i in range(repeat): + starts[i].record() + fn(*args) + ends[i].record() + torch.cuda.synchronize() + times = [s.elapsed_time(e) for s, e in zip(starts, ends)] + return sum(times) / len(times) + + +def run_config(bs: int, pages: int, K: int, reserved_bos: int = 1, reserved_eos: int = 1, + warmup: int = 20, repeat: int = 200) -> Dict[str, float]: + inp = make_inputs(bs, pages, K, reserved_bos, reserved_eos) + + # --- baseline: topk_output_sglang (single-CTA radix-select, mode=0) --- + sglang_args = ( + inp["x"], inp["dense_kv_indptr"], inp["sparse_kv_indptr"], + inp["dense_kv_indices"], inp["sparse_kv_indices"], + bs, K, reserved_bos, reserved_eos, pages, + ) + sglang_ms = time_kernel(topk_output_sglang, sglang_args, warmup, repeat) + + # --- naive CUB sort: topk_output (only if pages <= 8192 — template ladder limit) --- + naive_ms = float("nan") + if pages <= 8192: + naive_args = ( + inp["x"], inp["dense_kv_indptr"], inp["dense_kv_indices"], + inp["sparse_kv_indptr"], inp["sparse_kv_indices"], + bs, K, reserved_bos, reserved_eos, pages, + ) + try: + naive_ms = time_kernel(topk_output, naive_args, warmup, repeat) + except RuntimeError as e: + print(f"[naive skip] bs={bs} pages={pages} K={K}: {e}") + + # --- adaptive full --- + adaptive_args = ( + inp["x"], inp["dense_kv_indptr"], inp["sparse_kv_indptr"], + inp["dense_kv_indices"], inp["sparse_kv_indices"], + bs, K, reserved_bos, reserved_eos, pages, + 0, # mapping_mode = NONE + 0.5, # mapping_power (unused) + ) + adaptive_ms = time_kernel(topk_output_adaptive, adaptive_args, warmup, repeat) + + # --- adaptive Phase 1 only --- + p1_args = ( + inp["x"], inp["dense_kv_indptr"], inp["dense_kv_indices"], + inp["partial_scores"], inp["partial_indices"], + bs, K, reserved_bos, reserved_eos, pages, + ) + p1_ms = time_kernel(topk_adaptive_phase1_only, p1_args, warmup, repeat) + + # --- adaptive Phase 2 only (workspace pre-populated by the last p1 call) --- + p2_args = ( + inp["partial_scores"], inp["partial_indices"], + inp["sparse_kv_indptr"], inp["sparse_kv_indices"], + bs, K, reserved_bos, + ) + p2_ms = time_kernel(topk_adaptive_phase2_only, p2_args, warmup, repeat) + + overhead_ms = adaptive_ms - (p1_ms + p2_ms) + + return { + "bs": bs, "pages": pages, "K": K, + "naive_ms": naive_ms, + "sglang_ms": sglang_ms, + "adaptive_ms": adaptive_ms, + "phase1_ms": p1_ms, + "phase2_ms": p2_ms, + "p1_plus_p2_ms": p1_ms + p2_ms, + "overhead_ms": overhead_ms, + "overhead_frac": overhead_ms / adaptive_ms if adaptive_ms else 0.0, + "adaptive_vs_sglang": adaptive_ms / sglang_ms if sglang_ms else float("nan"), + } + + +def _fmt(v, w=9): + if isinstance(v, float) and math.isnan(v): + return f"{'—':>{w}s}" + if isinstance(v, float): + return f"{v:>{w}.4f}" + return f"{str(v):>{w}s}" + + +def print_table(rows: List[dict]) -> None: + hdr = (f"{'bs':>3s} {'pages':>6s} {'K':>5s} {'naive':>9s} {'sglang':>9s} " + f"{'adaptive':>9s} {'phase1':>9s} {'phase2':>9s} {'p1+p2':>9s} " + f"{'overhead':>9s} {'ovh%':>6s} {'a/sglang':>9s}") + sep = "-" * len(hdr) + print(sep) + print(hdr) + print(sep) + for r in rows: + ovh_pct = 100.0 * r["overhead_frac"] + print(f"{r['bs']:>3d} {r['pages']:>6d} {r['K']:>5d} " + f"{_fmt(r['naive_ms'])} {_fmt(r['sglang_ms'])} " + f"{_fmt(r['adaptive_ms'])} {_fmt(r['phase1_ms'])} {_fmt(r['phase2_ms'])} " + f"{_fmt(r['p1_plus_p2_ms'])} {_fmt(r['overhead_ms'])} " + f"{ovh_pct:>5.1f}% {r['adaptive_vs_sglang']:>8.3f}×") + print(sep) + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--gpu", type=int, default=4) + p.add_argument("--warmup", type=int, default=20) + p.add_argument("--repeat", type=int, default=200) + p.add_argument("--output-json", type=str, default=None) + args = p.parse_args() + + torch.cuda.set_device(args.gpu) + + # Sweep: small/medium/large bs × pages × K matrix exercising both + # the light path (K=30) and heavy path (K=2048). + configs = [ + # bs, pages, K + (1, 4096, 30), + (1, 16384, 30), + (1, 32768, 30), + (4, 4096, 30), + (4, 16384, 30), + (4, 32768, 30), + (16, 4096, 30), + (16, 32768, 30), + # heavy + (1, 4096, 2048), + (1, 16384, 2048), + (1, 32768, 2048), + (4, 4096, 2048), + (4, 16384, 2048), + (4, 32768, 2048), + (16, 4096, 2048), + (16, 32768, 2048), + ] + + rows = [] + for (bs, pages, K) in configs: + try: + row = run_config(bs, pages, K, warmup=args.warmup, repeat=args.repeat) + rows.append(row) + print(f"[done] bs={bs} pages={pages} K={K}") + except RuntimeError as e: + print(f"[skip] bs={bs} pages={pages} K={K}: {e}") + + print_table(rows) + + if args.output_json: + with open(args.output_json, "w") as f: + json.dump(rows, f, indent=2) + print(f"Saved: {args.output_json}") + + +if __name__ == "__main__": + main() diff --git a/csrc/topk_sglang_cluster.cu b/csrc/archived/topk_sglang_cluster.cu similarity index 100% rename from csrc/topk_sglang_cluster.cu rename to csrc/archived/topk_sglang_cluster.cu diff --git a/csrc/archived/topk_sglang_parallel.cu b/csrc/archived/topk_sglang_parallel.cu new file mode 100644 index 00000000..f11e59d3 --- /dev/null +++ b/csrc/archived/topk_sglang_parallel.cu @@ -0,0 +1,639 @@ +/** + * Vortex TopK — single-kernel parallel+merge pipeline. + * + * ONE kernel launch. Per-chunk selection and cross-chunk merge both run + * inside the same grid-(N, Batch) launch. The last-arriving CTA for + * each batch (detected by a program-lifetime __device__ done-counter + + * atomicInc wrap-around) carries out the merge — no second launch, no + * per-call cudaMemset for barrier state. + * + * Correctness: + * Stage 1 per-chunk uses ONE 8-bit radix histogram + ONE 8-bit + * refinement round on the threshold bin (16 bits of selection + * precision). For bf16 input (8 mantissa bits effective), this is + * lossless — two items with the same 16-bit key are bit-identical as + * bf16 values. + * + * Stage 2 merge operates on N*K pre-remapped keys in shared memory + * and uses the same 8-bit-hist + 8-bit-refine pattern, which is + * strictly sufficient to pick the correct top-K from the union. + * + * Low-overhead primitives: + * - Warp-level ballot+popc compaction on the "bin > threshold" path + * so each warp issues ONE atomicAdd on the block counter instead + * of one per thread. + * - Program-lifetime __device__ done-counter sized for realistic + * batch×head counts; atomicInc wraps back to 0 at num_chunks so + * there's no memset on the hot path. + * - Vectorised float4/int4 loads from global → smem in the merge. + * + * Supported mapping modes (IDs from csrc/topk_mapping.cuh): + * 3=POWER, 6=ASINH, 7=LOG1P, 9=ERF, 10=TANH, 11=SUBTRACT, + * 13=EXP_STRETCH, 15=SHIFT_POW2, 16=SHIFT_POW3, 17=LINEAR_STEEP. + */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + #include "register.h" + + namespace { + + // ---- Launch constants ------------------------------------------------------ + + constexpr int kThreadsPerBlock = 1024; + constexpr int kWarpSize = 32; + constexpr int RADIX = 256; + constexpr size_t kMaxDynSmem = 96 * 1024; + constexpr int VORTEX_MAX_TOPK = 2048; + + // Stage-2 holds N*K (key, idx) pairs in smem = 8 B/item. + constexpr int kMergeCap = 8192; + + // Max batch the single kernel can sequence. Sized for realistic + // bs×heads (decode). __device__ globals are zero-initialised at + // program start; atomicInc wrap-around keeps each entry at 0 between + // launches, so no host-side memset on the hot path. + constexpr int kMaxBatch = 8192; + __device__ unsigned int g_done_counter[kMaxBatch]; + + // ---- Device helpers -------------------------------------------------------- + + __device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + // Required symbol for topk_mapping.cuh's compute_stage1_bin. Not used + // directly by the kernel body here, but the header includes a forward + // declaration that resolves against this definition at link time. + __device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + template + __device__ __forceinline__ float vortex_to_float(T x); + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + #include "topk_mapping.cuh" + + // ============================================================================ + // 8-step suffix cumsum over 256 bins. After the call s_hist[0][i] is + // the count of items with bin >= i (monotone non-increasing). + // ============================================================================ + __device__ __forceinline__ void run_cumsum_256(int s_hist[2][RADIX + 128]) { + const int tx = threadIdx.x; + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + static_assert(1 << 8 == RADIX); + if (C10_LIKELY(tx < RADIX)) { + const int j = 1 << i; + const int k = i & 1; + int value = s_hist[k][tx]; + if (tx < RADIX - j) value += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = value; + } + __syncthreads(); + } + } + + // ============================================================================ + // Warp-level ballot+popc compaction. + // + // Every participating thread offers a boolean `selected`. Exactly ONE + // atomicAdd per warp — issued by the first active lane — reserves + // `warp_count` slots; other selected lanes derive their slot via a + // popc prefix sum. Safe when called from inside a divergent region + // (uses __activemask(), not a fixed all-ones mask). + // ============================================================================ + __device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { + const uint32_t mask = __activemask(); + const uint32_t ballot = __ballot_sync(mask, selected); + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_count = __popc(ballot); + const int rank_in_warp = __popc(ballot & ((1u << lane) - 1u)); + + const int first_lane = __ffs(mask) - 1; + int base = 0; + if (lane == first_lane) { + base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; + } + base = __shfl_sync(mask, base, first_lane); + return selected ? (base + rank_in_warp) : -1; + } + + // ============================================================================ + // Combined kernel — Stage 1 (per-chunk) + barrier + Stage 2 (merge). + // + // Grid = (Batch, N). One CTA per (batch, chunk). + // Block = kThreadsPerBlock = 1024. + // + // Shared-memory layout (reused across phases): + // Phase 1 needs: + // s_remapped[chunk_size] (float) — cached apply_transform output. + // s_bins[chunk_size] (uint8) — cached coarse bin. + // Merge needs: + // s_scores[N*K] (float) — pair buffer, loaded vectorised. + // s_indices[N*K] (int32) — pair buffer. + // kSmemBytes is sized to host max of both. + // + // Sync between phases: + // After Phase 1's workspace writes, __threadfence() publishes them, + // then thread 0 does `atomicInc(&g_done_counter[bx], N-1)` which + // cycles 0→1→…→N-1→0 so no reset is needed between calls. The CTA + // whose returned `old == N-1` is the last one — it falls through + // into the merge; other CTAs return. + // ============================================================================ + template + __global__ __launch_bounds__(kThreadsPerBlock) + void TopK_Parallel_Kernel( + const ScoreT* __restrict__ score, // [Batch, N, chunk_size] + int32_t* __restrict__ global_idx, // [Batch, K] + float* __restrict__ partial_keys, // [Batch, N, K] workspace + int32_t* __restrict__ partial_idx, // [Batch, N, K] workspace + int N, + int chunk_size, + int K, + float mapping_power) + { + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + // Addresses for this CTA's chunk slice and its slot in the workspace. + const ScoreT* chunk_in = score + (static_cast(b) * N + n) * chunk_size; + float* chunk_keys_out = partial_keys + (static_cast(b) * N + n) * K; + int32_t* chunk_idx_out = partial_idx + (static_cast(b) * N + n) * K; + const int32_t idx_base = n * chunk_size; // batch-local offset + + // ---------------------------------------------------------------- smem + extern __shared__ char smem_raw[]; + + // Shared-memory counters / histogram live in static smem so the + // Phase-1 and merge phases can share the same dynamic pool. + alignas(128) __shared__ int s_hist_buf[2][RADIX + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + alignas(128) __shared__ int s_is_last; + auto& s_hist = s_hist_buf[0]; + + // ========================================================================= + // Phase 1: per-chunk TopK via 8-bit radix + 8-bit refinement. + // ========================================================================= + // + // Dynamic smem region used as: + // s_remapped : chunk_size * 4 B (cached apply_transform output) + // s_bins : chunk_size * 1 B (cached Stage-1 bin) + // + // Refinement is a second 8-bit bucket on bits [23:16] of the + // sign-flipped u32 key, used to refine the threshold bin. 8 + 8 = + // 16 bits of selection precision → lossless for bf16. + float* s_remapped = reinterpret_cast(smem_raw); + uint8_t* s_bins = reinterpret_cast(s_remapped + chunk_size); + + // ---- Degenerate chunk_size <= K : emit everything as-is. ------------- + if (chunk_size <= K) { + for (int i = tx; i < K; i += blockDim.x) { + if (i < chunk_size) { + const float raw = vortex_to_float(chunk_in[i]); + chunk_keys_out[i] = apply_transform_tmpl(raw, mapping_power); + chunk_idx_out [i] = i + idx_base; + } else { + chunk_keys_out[i] = -CUDART_INF_F; + chunk_idx_out [i] = -1; + } + } + } else { + // ---- Histogram pass 1: transform + bucket; cache both to smem. ---- + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { s_counter = 0; s_threshold_bin = -1; s_last_remain = 0; } + __syncthreads(); + + for (int idx = tx; idx < chunk_size; idx += blockDim.x) { + const float raw = vortex_to_float(chunk_in[idx]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t b32 = convert_to_uint32(remapped); + const int bin = (b32 >> 24) & 0xFF; + s_remapped[idx] = remapped; + s_bins [idx] = static_cast(bin); + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin = s_threshold_bin; + + // ---- Emit bin > threshold (warp-popc) and build refinement hist. ---- + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + const int num_iters = (chunk_size + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + const bool take_above = in_range && (bin > threshold_bin); + + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + chunk_keys_out[slot] = s_remapped[idx]; + chunk_idx_out [slot] = idx + idx_base; + } else if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // ---- Refinement cumsum → sub-threshold bin. ------------------------ + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + // budget for items at the sub-threshold bin + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + // Only possible if last_remain == 0 (bin > threshold already emitted + // exactly K items). Nothing more to do; make the sub bin a sentinel. + s_sub_threshold_bin = RADIX; // no sub-threshold bin + } + __syncthreads(); + const int sub_threshold_bin = s_sub_threshold_bin; + + // ---- Emit threshold-bin items using sub-threshold logic. ---------- + for (int it = 0; it < num_iters; ++it) { + const int idx = it * blockDim.x + tx; + const bool in_range = (idx < chunk_size); + int bin = -1; + if (in_range) bin = static_cast(s_bins[idx]); + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[idx]); + sub_bin = (b32 >> 16) & 0xFF; + } + + const bool take_sub_above = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + chunk_keys_out[slot] = s_remapped[idx]; + chunk_idx_out [slot] = idx + idx_base; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + chunk_keys_out[K - pos] = s_remapped[idx]; + chunk_idx_out [K - pos] = idx + idx_base; + } + } + } + __syncthreads(); + } + + // ========================================================================= + // Barrier: publish this CTA's workspace writes and atomicInc the + // per-batch done-counter. The CTA that sees old == N-1 is the last + // arriving one; every other CTA returns here. + // ========================================================================= + __threadfence(); + __syncthreads(); + if (tx == 0) { + const unsigned int old = ::atomicInc( + &g_done_counter[b], static_cast(N - 1)); + s_is_last = (old == static_cast(N - 1)) ? 1 : 0; + } + __syncthreads(); + if (s_is_last == 0) return; + + // ========================================================================= + // Phase 2 (merge, only in last-arriving CTA): + // load N*K candidates into smem (vectorised) → + // 8-bit histogram in smem → + // threshold → warp-popc emit above + tie-bin refinement. + // ========================================================================= + const int total = N * K; + const float* keys_in = partial_keys + static_cast(b) * total; + const int32_t* idx_in = partial_idx + static_cast(b) * total; + int32_t* out_idx = global_idx + static_cast(b) * K; + + // Reuse the same dynamic smem region as Phase 1 — Phase 1's caches + // are dead now. Layout: [ s_scores : total floats | s_indices : total int32 ]. + float* s_scores = reinterpret_cast(smem_raw); + int32_t* s_indices = reinterpret_cast(s_scores + total); + + // Vectorised 128-bit loads when `total` is a multiple of 4. + if ((total & 3) == 0) { + const float4* keys_v = reinterpret_cast(keys_in); + const int4* idx_v = reinterpret_cast (idx_in); + float4* ss_v = reinterpret_cast (s_scores); + int4* si_v = reinterpret_cast (s_indices); + const int total4 = total >> 2; + for (int i = tx; i < total4; i += blockDim.x) { + ss_v[i] = keys_v[i]; + si_v[i] = idx_v [i]; + } + } else { + for (int i = tx; i < total; i += blockDim.x) { + s_scores [i] = keys_in[i]; + s_indices[i] = idx_in [i]; + } + } + + if (tx < RADIX + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + __syncthreads(); + + // (2) 8-bit histogram in smem. + const int num_iters_m = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); + } + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + + // Fast path: no threshold search needed when valid_count ≤ K. + const int valid_count = s_hist[0]; + if (valid_count <= K) { + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take) out_idx[slot] = s_indices[i]; + } + return; + } + + if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { + s_threshold_bin = tx; + s_last_remain = K - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; + + // (3) Emit above threshold via warp-popc; build sub-bin histogram on + // bits [23:16] for the tie-bin refinement. + if (tx < RADIX + 1) s_hist[tx] = 0; + __syncthreads(); + + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t idx = s_indices[i]; + if (idx >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + out_idx[slot] = s_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + // (4) Refinement cumsum → sub-threshold bin. + run_cumsum_256(s_hist_buf); + if (tx < RADIX && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) { + s_sub_threshold_bin = RADIX; // no tie-bin refinement needed + } + __syncthreads(); + const int sub_threshold_bin_m = s_sub_threshold_bin; + + // (5) Emit tie-bin items via warp-popc + sub-threshold budget. + for (int it = 0; it < num_iters_m; ++it) { + const int i = it * blockDim.x + tx; + bool in_threshold = false; + int sub_bin = -1; + if (i < total) { + const int32_t idx = s_indices[i]; + if (idx >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_threshold = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take_sub_above = in_threshold && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take_sub_above, &s_counter); + if (take_sub_above) { + out_idx[slot] = s_indices[i]; + } else if (in_threshold && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) out_idx[K - pos] = s_indices[i]; + } + } + } + + // ---- setup_kernel_smem_once ------------------------------------------------ + + template + void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + }(); + TORCH_CHECK(result == cudaSuccess, + "fast_fused_topk_merge setup failed: ", + ::cudaGetErrorString(result)); + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + // ============================================================================ + // Host entry point. + // + // score [batch_size, num_chunks, chunk_size] bf16 or f32 + // global_topk_indices [batch_size, topk_val] int32 (output) + // + // ONE kernel launch. The per-chunk selection (Phase 1) and the + // cross-chunk merge (Phase 2) are fused in TopK_Parallel_Kernel via a + // last-CTA-wins atomicInc barrier. A per-call workspace holds the + // [batch, N, K] partial top-K that the last CTA reads from; the + // done-counter is a program-lifetime __device__ global so nothing + // needs memsetting on the hot path. + // ============================================================================ + void fast_fused_topk_merge( + const at::Tensor& score, + at::Tensor& global_topk_indices, + const int64_t batch_size, + const int64_t num_chunks, + const int64_t chunk_size, + const int64_t topk_val, + const int64_t mapping_mode, + const double mapping_power) + { + CHECK_CUDA(score); + CHECK_CUDA(global_topk_indices); + + TORCH_CHECK(topk_val > 0 && topk_val <= VORTEX_MAX_TOPK, + "fast_fused_topk_merge: topk_val=", topk_val, + " must be in (0, ", VORTEX_MAX_TOPK, "]"); + TORCH_CHECK(num_chunks >= 1, "num_chunks must be >= 1"); + TORCH_CHECK(batch_size >= 1, "batch_size must be >= 1"); + TORCH_CHECK(batch_size <= kMaxBatch, + "fast_fused_topk_merge: batch_size ", batch_size, + " exceeds the __device__ done-counter cap (", kMaxBatch, ")"); + TORCH_CHECK(chunk_size >= 1, "chunk_size must be >= 1"); + TORCH_CHECK(num_chunks * topk_val <= kMergeCap, + "fast_fused_topk_merge: num_chunks*topk_val (", + num_chunks * topk_val, ") exceeds merge cap (", kMergeCap, + "). Reduce num_chunks or topk_val."); + TORCH_CHECK(global_topk_indices.scalar_type() == at::kInt, + "global_topk_indices must be int32"); + TORCH_CHECK(global_topk_indices.numel() >= batch_size * topk_val, + "global_topk_indices is too small for batch_size * topk_val"); + + TORCH_CHECK( + mapping_mode == MAPPING_POWER || + mapping_mode == MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP, + "fast_fused_topk_merge: mapping_mode=", mapping_mode, + " not supported. Valid: POWER(3), ASINH(6), LOG1P(7), ERF(9), " + "TANH(10), SUBTRACT(11), EXP_STRETCH(13), SHIFT_POW2(15), " + "SHIFT_POW3(16), LINEAR_STEEP(17)."); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + // Dynamic smem must fit whichever phase is larger: + // Phase 1: chunk_size floats + chunk_size bytes. + // Phase 2: num_chunks*topk_val * (float + int32). + const size_t p1_bytes = static_cast(chunk_size) * sizeof(float) + + ((static_cast(chunk_size) + 15) & ~size_t(15)); + const size_t p2_bytes = static_cast(num_chunks) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + const size_t smem_bytes = p1_bytes > p2_bytes ? p1_bytes : p2_bytes; + TORCH_CHECK(smem_bytes <= kMaxDynSmem, + "fast_fused_topk_merge: smem ", smem_bytes, + " > ceiling ", kMaxDynSmem); + + // Per-call workspace for the [batch, N, K] partial top-K. at::empty + // hits the caching allocator (no cudaMalloc in the hot path after + // warmup). The done-counter lives in __device__ memory — no memset. + auto opts_f32 = at::TensorOptions().device(score.device()).dtype(at::kFloat); + auto opts_i32 = at::TensorOptions().device(score.device()).dtype(at::kInt); + const int64_t ws_elems = batch_size * num_chunks * topk_val; + at::Tensor partial_keys = at::empty({ws_elems}, opts_f32); + at::Tensor partial_idx = at::empty({ws_elems}, opts_i32); + + dim3 grid(static_cast(batch_size), + static_cast(num_chunks)); + dim3 block(kThreadsPerBlock); + + #define LAUNCH(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + setup_kernel_smem_once, \ + kMaxDynSmem>(); \ + TopK_Parallel_Kernel \ + <<>>( \ + PTR_EXPR, \ + global_topk_indices.data_ptr(), \ + partial_keys.data_ptr(), \ + partial_idx.data_ptr(), \ + static_cast(num_chunks), \ + static_cast(chunk_size), \ + static_cast(topk_val), \ + mp); \ + } while (0) + + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + switch (mapping_mode) { \ + case MAPPING_POWER: LAUNCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_ASINH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + default: TORCH_CHECK(false, "unreachable mode"); \ + } \ + } while (0) + + if (score.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(score.data_ptr())); + } else if (score.scalar_type() == at::ScalarType::Float) { + DISPATCH_MODE(float, score.data_ptr()); + } else { + TORCH_CHECK(false, "fast_fused_topk_merge: unsupported dtype ", + score.scalar_type()); + } + + #undef DISPATCH_MODE + #undef LAUNCH + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "fast_fused_topk_merge kernel failed: ", ::cudaGetErrorString(rc)); + } \ No newline at end of file diff --git a/csrc/register.cc b/csrc/register.cc index cf1be880..9771bbd9 100644 --- a/csrc/register.cc +++ b/csrc/register.cc @@ -30,16 +30,61 @@ PYBIND11_MODULE(vortex_torch_C, m){ py::arg("mapping_power"), py::arg("mapping_lut") = py::none(), py::arg("mapping_quantiles") = py::none()); - m.def("fast_fused_topk_merge", &fast_fused_topk_merge, - py::arg("score"), py::arg("global_topk_indices"), - py::arg("batch_size"), py::arg("num_chunks"), - py::arg("chunk_size"), py::arg("topk_val"), - py::arg("mapping_mode"), py::arg("mapping_power")); - m.def("fast_cluster_topk_merge", &fast_cluster_topk_merge, - py::arg("score"), py::arg("global_topk_indices"), - py::arg("batch_size"), py::arg("num_chunks"), - py::arg("chunk_size"), py::arg("topk_val"), - py::arg("mapping_mode"), py::arg("mapping_power")); + m.def("topk_output_adaptive", &topk_output_adaptive, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power")); + m.def("topk_output_adaptive_workspace", &topk_output_adaptive_workspace, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("partial_keys"), py::arg("partial_indices"), + py::arg("done_counter"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("forced_splits") = -1, + py::arg("forced_partition") = -1, + py::arg("local_mode") = 0); + m.def("topk_output_adaptive_workspace_midk", + &topk_output_adaptive_workspace_midk, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("partial_keys"), py::arg("partial_indices"), + py::arg("done_counter"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("mapping_mode"), + py::arg("mapping_power"), + py::arg("forced_splits") = -1); + m.def("topk_output_adaptive_workspace_ablation", + &topk_output_adaptive_workspace_ablation, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("sparse_kv_indptr"), + py::arg("dense_kv_indices"), py::arg("sparse_kv_indices"), + py::arg("partial_keys"), py::arg("partial_indices"), + py::arg("done_counter"), py::arg("scratch"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages"), + py::arg("ablation_mode"), + py::arg("forced_splits") = 8); + m.def("topk_adaptive_phase1_only", &topk_adaptive_phase1_only, + py::arg("x"), py::arg("dense_kv_indptr"), py::arg("dense_kv_indices"), + py::arg("partial_scores"), py::arg("partial_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos"), py::arg("reserved_eos"), + py::arg("max_num_pages")); + m.def("topk_adaptive_phase2_only", &topk_adaptive_phase2_only, + py::arg("partial_scores"), py::arg("partial_indices"), + py::arg("sparse_kv_indptr"), py::arg("sparse_kv_indices"), + py::arg("eff_batch_size"), py::arg("topk_val"), + py::arg("reserved_bos")); m.def("topk_remap_only", &topk_remap_only, py::arg("x"), py::arg("dense_kv_indptr"), py::arg("remapped"), diff --git a/csrc/register.h b/csrc/register.h index ad2baaca..b565c3ef 100644 --- a/csrc/register.h +++ b/csrc/register.h @@ -126,33 +126,101 @@ std::optional mapping_lut = std::nullopt, std::optional mapping_quantiles = std::nullopt ); -// Two-stage parallel TopK. See csrc/topk_sglang_parallel.cu. -// score: [batch_size, num_chunks, chunk_size] bfloat16 or float32 -// global_topk_indices: [batch_size, topk_val] int32 (output) -// Caller must ensure num_chunks * topk_val <= 8192 (merge smem cap). -void fast_fused_topk_merge( -const at::Tensor& score, -at::Tensor& global_topk_indices, -const int64_t batch_size, -const int64_t num_chunks, -const int64_t chunk_size, +void topk_output_adaptive( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, const int64_t mapping_mode, const double mapping_power ); -// Hopper TBC+DSMEM fused TopK merge. See csrc/topk_sglang_cluster.cu. -// Same signature as fast_fused_topk_merge; num_chunks is the cluster -// size and is capped at 8 (portable TBC). Requires sm_90+. -void fast_cluster_topk_merge( -const at::Tensor& score, -at::Tensor& global_topk_indices, -const int64_t batch_size, -const int64_t num_chunks, -const int64_t chunk_size, +void topk_output_adaptive_workspace( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& partial_keys, +at::Tensor& partial_indices, +at::Tensor& done_counter, +const int64_t eff_batch_size, const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, const int64_t mapping_mode, -const double mapping_power +const double mapping_power, +const int64_t forced_splits, +const int64_t forced_partition, +const int64_t local_mode +); + +void topk_output_adaptive_workspace_midk( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& partial_keys, +at::Tensor& partial_indices, +at::Tensor& done_counter, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t mapping_mode, +const double mapping_power, +const int64_t forced_splits +); + +void topk_output_adaptive_workspace_ablation( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& sparse_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& sparse_kv_indices, +at::Tensor& partial_keys, +at::Tensor& partial_indices, +at::Tensor& done_counter, +at::Tensor& scratch, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages, +const int64_t ablation_mode, +const int64_t forced_splits +); + +void topk_adaptive_phase1_only( +const at::Tensor& x, +const at::Tensor& dense_kv_indptr, +const at::Tensor& dense_kv_indices, +at::Tensor& partial_scores, +at::Tensor& partial_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos, +const int64_t reserved_eos, +const int64_t max_num_pages +); + +void topk_adaptive_phase2_only( +const at::Tensor& partial_scores, +const at::Tensor& partial_indices, +const at::Tensor& sparse_kv_indptr, +at::Tensor& sparse_kv_indices, +const int64_t eff_batch_size, +const int64_t topk_val, +const int64_t reserved_bos ); void topk_remap_only( diff --git a/csrc/topk_adaptive_profile.cu b/csrc/topk_adaptive_profile.cu new file mode 100644 index 00000000..fc706da9 --- /dev/null +++ b/csrc/topk_adaptive_profile.cu @@ -0,0 +1,1145 @@ +/** + * Profile-only fixtures for the adaptive split TopK. Two distinct fixtures: + * + * [1] LEGACY split-2 histogram fixture (TopK_Phase1_Only_Kernel, + * TopK_Phase2_Only_Kernel, topk_adaptive_phase1_only, + * topk_adaptive_phase2_only). + * Implements an 8-bit coarse histogram + 8-bit refinement on bits [23:16] + * with a fixed split count of 2 (kNumSplits=2, kThreads=1024). + * Kept for historical comparison only; this is NOT the K=30 production + * path and is NOT representative of the current split kernel in + * topk_sglang_merge.cu. barrier = full_adaptive - (phase1 + phase2). + * + * [2] K=30 ablation fixture (Ablation_*_Kernel, + * topk_output_adaptive_workspace_ablation). + * Measures isolated costs: local sort, workspace write, atomic/fence, + * merge-only (multiple variants), memset-only, and the full adaptive path. + * Split configs exactly match production (kAblCfg1..kAblCfg32). + * ScoreT is bf16 only; mode is hardcoded to MAPPING_NONE. + * Production kernel lives in topk_sglang_merge.cu. + * Current production merge = MERGE_CUB_WARP (kAblMode_MergeCubWarp = 6). + * MERGE_PROD_DEFAULT (mode 5) is the legacy per-SPLITS dispatch kept for + * ablation comparison (not the current production merge). + * + * ablation_mode constants and merge variant constants are defined below in + * the K=30 ablation namespace section. + */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + #include + #include + + #include "register.h" + + namespace { + + constexpr int kRadix = 256; + constexpr int kThreads = 1024; + constexpr int kWarpSize = 32; + constexpr int kNumSplits = 2; + constexpr size_t kMaxDynSmem = 96 * 1024; + + __device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + __device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + __device__ __forceinline__ void run_cumsum_256(int s_hist[2][kRadix + 128]) { + const int tx = threadIdx.x; + #pragma unroll 8 + for (int i = 0; i < 8; ++i) { + if (tx < kRadix) { + const int j = 1 << i; + const int k = i & 1; + int v = s_hist[k][tx]; + if (tx < kRadix - j) v += s_hist[k][tx + j]; + s_hist[k ^ 1][tx] = v; + } + __syncthreads(); + } + } + + __device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { + const uint32_t mask = __activemask(); + const uint32_t ballot = __ballot_sync(mask, selected); + const int lane = threadIdx.x & (kWarpSize - 1); + const int warp_count = __popc(ballot); + const int rank = __popc(ballot & ((1u << lane) - 1u)); + const int first = __ffs(mask) - 1; + int base = 0; + if (lane == first) { + base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; + } + base = __shfl_sync(mask, base, first); + return selected ? (base + rank) : -1; + } + + // ============================================================================ + // Phase 1 ONLY: per-chunk radix select, writes unordered (score, idx) pairs + // into partial_scores/partial_indices. No barrier, no merge. + // ============================================================================ + __global__ __launch_bounds__(kThreads) + void TopK_Phase1_Only_Kernel( + const __nv_bfloat16* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ dense_kv_indices, + float* __restrict__ partial_scores, + int32_t* __restrict__ partial_indices, + const int topk_val, + const int reserved_bos, + const int reserved_eos) + { + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = row_end - row_start; + const int half = (row_len + 1) / 2; + const int ck_begin = (n == 0) ? 0 : half; + const int ck_end = (n == 0) ? half : row_len; + const int ck_len = ck_end - ck_begin; + + const __nv_bfloat16* chunk_in = score + row_start + ck_begin; + const int* idx_map = dense_kv_indices + row_start + ck_begin; + float* part_keys = partial_scores + (static_cast(b) * kNumSplits + n) * topk_val; + int32_t* part_idx = partial_indices + (static_cast(b) * kNumSplits + n) * topk_val; + + extern __shared__ char smem_raw[]; + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + auto& s_hist = s_hist_buf[0]; + + float* s_remapped = reinterpret_cast(smem_raw); + uint8_t* s_bins = reinterpret_cast(s_remapped + ck_len); + + if (ck_len <= topk_val) { + for (int i = tx; i < topk_val; i += blockDim.x) { + if (i < ck_len) { + part_keys[i] = __bfloat162float(chunk_in[i]); + part_idx [i] = idx_map[i]; + } else { + part_keys[i] = -CUDART_INF_F; + part_idx [i] = -1; + } + } + return; + } + + if (tx < kRadix + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + __syncthreads(); + + for (int i = tx; i < ck_len; i += blockDim.x) { + const float v = __bfloat162float(chunk_in[i]); + const uint8_t bin = convert_to_uint8(v); + s_remapped[i] = v; + s_bins [i] = bin; + ::atomicAdd(&s_hist[bin], 1); + } + __syncthreads(); + run_cumsum_256(s_hist_buf); + + if (tx < kRadix && s_hist[tx] > topk_val && s_hist[tx + 1] <= topk_val) { + s_threshold_bin = tx; + s_last_remain = topk_val - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin = s_threshold_bin; + + if (tx < kRadix + 1) s_hist[tx] = 0; + __syncthreads(); + + const int num_iters = (ck_len + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + const bool in_range = (i < ck_len); + int bin = -1; + if (in_range) bin = s_bins[i]; + const bool take_above = in_range && (bin > threshold_bin); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + part_keys[slot] = s_remapped[i]; + part_idx [slot] = idx_map[i]; + } else if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[i]); + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + run_cumsum_256(s_hist_buf); + if (tx < kRadix && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) s_sub_threshold_bin = kRadix; + __syncthreads(); + const int sub_threshold_bin = s_sub_threshold_bin; + + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + const bool in_range = (i < ck_len); + int bin = -1; + if (in_range) bin = s_bins[i]; + int sub_bin = -1; + if (in_range && bin == threshold_bin) { + const uint32_t b32 = convert_to_uint32(s_remapped[i]); + sub_bin = (b32 >> 16) & 0xFF; + } + const bool take_sub = (sub_bin > sub_threshold_bin); + const int slot = warp_compact_slot(take_sub, &s_counter); + if (take_sub) { + part_keys[slot] = s_remapped[i]; + part_idx [slot] = idx_map[i]; + } else if (sub_bin == sub_threshold_bin) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) { + part_keys[topk_val - pos] = s_remapped[i]; + part_idx [topk_val - pos] = idx_map[i]; + } + } + } + } + + // ============================================================================ + // Phase 2 ONLY: read pre-populated (kNumSplits * K) candidates, radix-select + // top-K, write final indices. Grid = (batch,). + // ============================================================================ + __global__ __launch_bounds__(kThreads) + void TopK_Phase2_Only_Kernel( + const float* __restrict__ partial_scores, + const int32_t* __restrict__ partial_indices, + const int* __restrict__ sparse_kv_indptr, + int32_t* __restrict__ sparse_kv_indices, + const int topk_val, + const int reserved_bos) + { + const int b = blockIdx.x; + const int tx = threadIdx.x; + const int total = kNumSplits * topk_val; + + const float* keys_in = partial_scores + static_cast(b) * total; + const int32_t* idx_in = partial_indices + static_cast(b) * total; + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + + extern __shared__ char smem_raw[]; + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + alignas(128) __shared__ int s_counter; + alignas(128) __shared__ int s_threshold_bin; + alignas(128) __shared__ int s_sub_threshold_bin; + alignas(128) __shared__ int s_last_remain; + auto& s_hist = s_hist_buf[0]; + + float* s_scores = reinterpret_cast(smem_raw); + int32_t* s_indices = reinterpret_cast(s_scores + total); + + if ((total & 3) == 0) { + const float4* kv = reinterpret_cast(keys_in); + const int4* iv = reinterpret_cast (idx_in); + float4* sv = reinterpret_cast(s_scores); + int4* iiv = reinterpret_cast (s_indices); + const int total4 = total >> 2; + for (int i = tx; i < total4; i += blockDim.x) { + sv[i] = kv[i]; + iiv[i] = iv[i]; + } + } else { + for (int i = tx; i < total; i += blockDim.x) { + s_scores[i] = keys_in[i]; + s_indices[i] = idx_in[i]; + } + } + + if (tx < kRadix + 1) s_hist[tx] = 0; + if (tx == 0) { + s_counter = 0; + s_threshold_bin = -1; + s_sub_threshold_bin = -1; + s_last_remain = 0; + } + __syncthreads(); + + const int num_iters = (total + blockDim.x - 1) / blockDim.x; + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + if (i < total && s_indices[i] >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + ::atomicAdd(&s_hist[bin], 1); + } + } + __syncthreads(); + run_cumsum_256(s_hist_buf); + + const int valid_count = s_hist[0]; + if (valid_count <= topk_val) { + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + const bool take = (i < total) && (s_indices[i] >= 0); + const int slot = warp_compact_slot(take, &s_counter); + if (take && slot < topk_val) out_idx[slot] = s_indices[i]; + } + return; + } + + if (tx < kRadix && s_hist[tx] > topk_val && s_hist[tx + 1] <= topk_val) { + s_threshold_bin = tx; + s_last_remain = topk_val - s_hist[tx + 1]; + } + __syncthreads(); + const int threshold_bin_m = s_threshold_bin; + + if (tx < kRadix + 1) s_hist[tx] = 0; + __syncthreads(); + + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + bool in_valid = false; + int bin = -1; + uint32_t b32 = 0; + if (i < total) { + const int32_t ii = s_indices[i]; + if (ii >= 0) { + in_valid = true; + b32 = convert_to_uint32(s_scores[i]); + bin = (b32 >> 24) & 0xFF; + } + } + const bool take_above = in_valid && (bin > threshold_bin_m); + const int slot = warp_compact_slot(take_above, &s_counter); + if (take_above) { + out_idx[slot] = s_indices[i]; + } else if (in_valid && bin == threshold_bin_m) { + const int sub_bin = (b32 >> 16) & 0xFF; + ::atomicAdd(&s_hist[sub_bin], 1); + } + } + __syncthreads(); + + run_cumsum_256(s_hist_buf); + if (tx < kRadix && s_hist[tx] > s_last_remain + && s_hist[tx + 1] <= s_last_remain) { + s_sub_threshold_bin = tx; + s_last_remain = s_last_remain - s_hist[tx + 1]; + } + if (tx == 0 && s_sub_threshold_bin == -1) s_sub_threshold_bin = kRadix; + __syncthreads(); + const int sub_threshold_bin_m = s_sub_threshold_bin; + + for (int it = 0; it < num_iters; ++it) { + const int i = it * blockDim.x + tx; + bool in_thr = false; + int sub_bin = -1; + if (i < total) { + const int32_t ii = s_indices[i]; + if (ii >= 0) { + const uint32_t b32 = convert_to_uint32(s_scores[i]); + const int bin = (b32 >> 24) & 0xFF; + if (bin == threshold_bin_m) { + in_thr = true; + sub_bin = (b32 >> 16) & 0xFF; + } + } + } + const bool take = in_thr && (sub_bin > sub_threshold_bin_m); + const int slot = warp_compact_slot(take, &s_counter); + if (take) { + out_idx[slot] = s_indices[i]; + } else if (in_thr && sub_bin == sub_threshold_bin_m) { + const int pos = ::atomicAdd(&s_last_remain, -1); + if (pos > 0) out_idx[topk_val - pos] = s_indices[i]; + } + } + } + + template + void setup_smem_once() { + [[maybe_unused]] static const auto r = [] { + return ::cudaFuncSetAttribute(f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dyn); + }(); + TORCH_CHECK(r == cudaSuccess, "profile kernel setup failed: ", + ::cudaGetErrorString(r)); + } + + } // namespace + + #define CHECK_CUDA_T(x) TORCH_CHECK(x.is_cuda(), #x " must be CUDA") + + // ============================================================================ + // Host entry — Phase 1 only. + // ============================================================================ + void topk_adaptive_phase1_only( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& partial_scores, + at::Tensor& partial_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages) + { + CHECK_CUDA_T(x); CHECK_CUDA_T(dense_kv_indptr); + CHECK_CUDA_T(dense_kv_indices); CHECK_CUDA_T(partial_scores); CHECK_CUDA_T(partial_indices); + TORCH_CHECK(x.scalar_type() == at::ScalarType::BFloat16, + "profile kernels require bfloat16 input"); + + const int chunk_max = (static_cast(max_num_pages) + 1) / 2; + const size_t smem = static_cast(chunk_max) * sizeof(float) + + ((static_cast(chunk_max) + 15) & ~size_t(15)); + TORCH_CHECK(smem <= kMaxDynSmem, "phase1 smem too large"); + + setup_smem_once(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + dim3 grid(static_cast(eff_batch_size), + static_cast(kNumSplits)); + TopK_Phase1_Only_Kernel<<>>( + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()), + dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + partial_scores.data_ptr(), + partial_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos), + static_cast(reserved_eos)); + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "phase1 launch failed"); + } + + // ============================================================================ + // Host entry — Phase 2 only (expects partial_* pre-populated). + // ============================================================================ + void topk_adaptive_phase2_only( + const at::Tensor& partial_scores, + const at::Tensor& partial_indices, + const at::Tensor& sparse_kv_indptr, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos) + { + CHECK_CUDA_T(partial_scores); CHECK_CUDA_T(partial_indices); + CHECK_CUDA_T(sparse_kv_indptr); CHECK_CUDA_T(sparse_kv_indices); + + const size_t smem = static_cast(kNumSplits) * + static_cast(topk_val) * + (sizeof(float) + sizeof(int32_t)); + TORCH_CHECK(smem <= kMaxDynSmem, "phase2 smem too large"); + + setup_smem_once(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + dim3 grid(static_cast(eff_batch_size)); + TopK_Phase2_Only_Kernel<<>>( + partial_scores.data_ptr(), + partial_indices.data_ptr(), + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + TORCH_CHECK(cudaGetLastError() == cudaSuccess, "phase2 launch failed"); + } + + // ============================================================================= + // K=30 phase ablation kernels and host entry. Bench-only fixture for + // `bench_ablation.py`. NOT a production code path. The production K=30 + // random-split parallel kernel and dispatcher live in topk_sglang_merge.cu. + // + // All kernels here are hardcoded to ScoreT=bf16, MAPPING_NONE, partition + // = CONTIGUOUS to keep the template instantiation count small. They share + // no code with the production path beyond the function declarations in + // register.h; helpers are duplicated below in the anonymous namespace. + // ============================================================================= + namespace { + + constexpr int kLocalK_Top30 = 32; + constexpr int kMaxFinalK_Top30 = 32; + constexpr int kPartContiguous = 1; // PART_CONTIGUOUS in topk_sglang_merge.cu + + // Per-split (NUM_THREADS, ITEMS_PER_THREAD) — must match the production + // configurations in topk_sglang_merge.cu (kCfg1..kCfg32) so the ablation + // numbers reflect the production launch parameters. + struct AblSplitCfg { int num_threads, items_per_thread; }; + constexpr AblSplitCfg kAblCfg1 = { 1024, 8 }; + constexpr AblSplitCfg kAblCfg2 = { 1024, 8 }; + constexpr AblSplitCfg kAblCfg4 = { 512, 8 }; + constexpr AblSplitCfg kAblCfg8 = { 256, 16 }; + constexpr AblSplitCfg kAblCfg16 = { 128, 16 }; + constexpr AblSplitCfg kAblCfg32 = { 64, 16 }; + + // --------------------------------------------------------------------------- + // Ablation mode constants (ablation_mode argument to + // topk_output_adaptive_workspace_ablation). + // --------------------------------------------------------------------------- + constexpr int kAblMode_FullAdaptive = 0; // full production path (reference) + constexpr int kAblMode_LocalWithWorkspace = 1; // local sort + workspace write, no merge + constexpr int kAblMode_LocalNoWorkspace = 2; // local sort only, no write, no merge + constexpr int kAblMode_WorkspaceWriteOnly = 3; // synthetic write to workspace + constexpr int kAblMode_AtomicOnly = 4; // atomic counter cost only + constexpr int kAblMode_MergeProdDefault = 5; // merge: legacy per-SPLITS dispatch + // (2-way for SPLITS=2, pairwise for SPLITS=4, + // k-way for SPLITS>=8). NOT current production. + constexpr int kAblMode_MergeCubWarp = 6; // merge: cub::WarpMergeSort — current production + constexpr int kAblMode_MergeCubBlock = 7; // merge: cub::BlockMergeSort benchmark + constexpr int kAblMode_MemsetOnly = 8; // counter memset cost only + constexpr int kAblMode_MergeManual2Way = 9; // merge: manual 2-way (requires split=2) + constexpr int kAblMode_MergePairwise4 = 10; // merge: pairwise tree (requires split=4) + constexpr int kAblMode_MergeKwayAll = 11; // merge: force k-way for all split counts + + // --------------------------------------------------------------------------- + // Merge variant indices for Ablation_MergeOnly_Kernel. + // --------------------------------------------------------------------------- + constexpr int MERGE_PROD_DEFAULT = 0; // legacy: 2-way(SPLITS=2)/pairwise(SPLITS=4)/k-way(>=8) + constexpr int MERGE_CUB_WARP = 1; // cub::WarpMergeSort — matches current production merge + // kMergeIPT=SPLITS; register pressure grows with SPLITS + constexpr int MERGE_CUB_BLOCK = 2; // cub::BlockMergeSort (benchmark; 64 threads) + constexpr int MERGE_MANUAL_2WAY = 3; // manual 2-way merge (requires SPLITS=2) + constexpr int MERGE_PAIRWISE_4 = 4; // pairwise tree (requires SPLITS=4) + constexpr int MERGE_KWAY = 5; // force k-way for all SPLITS (explicit baseline) + + template + __device__ __forceinline__ float vortex_to_float_p(T x); + template <> + __device__ __forceinline__ float vortex_to_float_p(float x) { return x; } + template <> + __device__ __forceinline__ float vortex_to_float_p<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + struct AblGreaterUint32 { + __device__ __forceinline__ bool operator()(uint32_t a, uint32_t b) const { + return a > b; + } + }; + + // k-way merge — same algorithm as topk_sglang_merge.cu's merge_sorted_kway. + template + __device__ __forceinline__ void abl_merge_sorted_kway( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, + int final_k) + { + const int lane = threadIdx.x & 31; + const bool is_my_list = (lane < SPLITS); + const uint32_t full = 0xFFFFFFFFu; + + int ptr = 0; + uint32_t cur_key = is_my_list ? keys_in[lane * LOCAL_K] : 0u; + int32_t cur_idx = is_my_list ? idx_in [lane * LOCAL_K] : -1; + + #pragma unroll + for (int t = 0; t < MAX_FINAL_K; ++t) { + uint32_t best_key = cur_key; + int best_lane = lane; + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + uint32_t okey = __shfl_xor_sync(full, best_key, offset); + int olane = __shfl_xor_sync(full, best_lane, offset); + bool take = (okey > best_key) || (okey == best_key && olane < best_lane); + best_key = take ? okey : best_key; + best_lane = take ? olane : best_lane; + } + int32_t win_idx = __shfl_sync(full, cur_idx, best_lane); + if (lane == 0 && t < final_k && win_idx >= 0) out_idx[t] = win_idx; + if (lane == best_lane) { + ++ptr; + if (is_my_list && ptr < LOCAL_K) { + cur_key = keys_in[lane * LOCAL_K + ptr]; + cur_idx = idx_in [lane * LOCAL_K + ptr]; + } else { + cur_key = 0u; + cur_idx = -1; + } + } + } + } + + __device__ __forceinline__ void abl_merge_2way_manual_lane0( + const uint32_t* __restrict__ l0_keys, const int32_t* __restrict__ l0_idx, int n0, + const uint32_t* __restrict__ l1_keys, const int32_t* __restrict__ l1_idx, int n1, + int32_t* __restrict__ out_idx, + int final_k) + { + if (threadIdx.x != 0) return; + int p0 = 0, p1 = 0; + for (int t = 0; t < final_k; ++t) { + const uint32_t k0 = (p0 < n0) ? l0_keys[p0] : 0u; + const uint32_t k1 = (p1 < n1) ? l1_keys[p1] : 0u; + if (k0 >= k1 && p0 < n0) { + out_idx[t] = l0_idx[p0]; ++p0; + } else if (p1 < n1) { + out_idx[t] = l1_idx[p1]; ++p1; + } else { + out_idx[t] = -1; + } + } + } + + __device__ __forceinline__ void abl_merge_pairwise_4( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, int final_k, + uint32_t* __restrict__ tmp01_keys, int32_t* __restrict__ tmp01_idx, + uint32_t* __restrict__ tmp23_keys, int32_t* __restrict__ tmp23_idx) + { + constexpr int LK = kLocalK_Top30; + const int lane = threadIdx.x & 31; + if (lane == 0) { + int p0 = 0, p1 = 0; + #pragma unroll + for (int t = 0; t < LK; ++t) { + const uint32_t k0 = (p0 < LK) ? keys_in[0 * LK + p0] : 0u; + const uint32_t k1 = (p1 < LK) ? keys_in[1 * LK + p1] : 0u; + if (k0 >= k1 && p0 < LK) { tmp01_keys[t] = k0; tmp01_idx[t] = idx_in[0*LK+p0]; ++p0; } + else if (p1 < LK) { tmp01_keys[t] = k1; tmp01_idx[t] = idx_in[1*LK+p1]; ++p1; } + else { tmp01_keys[t] = 0u; tmp01_idx[t] = -1; } + } + } else if (lane == 1) { + int p2 = 0, p3 = 0; + #pragma unroll + for (int t = 0; t < LK; ++t) { + const uint32_t k2 = (p2 < LK) ? keys_in[2 * LK + p2] : 0u; + const uint32_t k3 = (p3 < LK) ? keys_in[3 * LK + p3] : 0u; + if (k2 >= k3 && p2 < LK) { tmp23_keys[t] = k2; tmp23_idx[t] = idx_in[2*LK+p2]; ++p2; } + else if (p3 < LK) { tmp23_keys[t] = k3; tmp23_idx[t] = idx_in[3*LK+p3]; ++p3; } + else { tmp23_keys[t] = 0u; tmp23_idx[t] = -1; } + } + } + __syncwarp(); + if (lane == 0) { + int p0 = 0, p1 = 0; + for (int t = 0; t < final_k; ++t) { + const uint32_t k0 = (p0 < LK) ? tmp01_keys[p0] : 0u; + const uint32_t k1 = (p1 < LK) ? tmp23_keys[p1] : 0u; + if (k0 >= k1 && p0 < LK) { out_idx[t] = tmp01_idx[p0]; ++p0; } + else if (p1 < LK) { out_idx[t] = tmp23_idx[p1]; ++p1; } + else { out_idx[t] = -1; } + } + } + } + + // uint32 sort key for an fp32 value. + __device__ __forceinline__ uint32_t abl_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + // ---- Ablation kernels -------------------------------------------------------- + + template + __global__ __launch_bounds__(NUM_THREADS) + void Ablation_LocalOnly_Kernel( + const __nv_bfloat16* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ dense_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + const int reserved_bos, + const int reserved_eos) + { + // Stage 1 + workspace write. No atomic. No merge. + using KeyT = uint32_t; using ValueT = int32_t; + using BlockSortT = cub::BlockRadixSort; + __shared__ typename BlockSortT::TempStorage sort_smem; + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + if (row_len <= 0) return; + const int group_begin = (row_len * n) / SPLITS; + const int group_end = (row_len * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + const __nv_bfloat16* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + KeyT keys[ITEMS_PER_THREAD]; ValueT values[ITEMS_PER_THREAD]; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int local_rank = tx + k * NUM_THREADS; + if (local_rank < group_len) { + const int pos = group_begin + local_rank; + const float raw = vortex_to_float_p(row_scores[pos]); + keys [k] = abl_to_uint32(raw); + values[k] = row_idxmap[pos]; + } else { keys[k] = 0u; values[k] = -1; } + } + BlockSortT(sort_smem).SortDescending(keys, values); + __syncthreads(); + constexpr int LK = kLocalK_Top30; + const int64_t part_off = (static_cast(b) * SPLITS + n) * LK; + uint32_t* part_keys = partial_keys + part_off; + int32_t* part_idx = partial_indices + part_off; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int rank = tx * ITEMS_PER_THREAD + k; + if (rank < LK) { part_keys[rank] = keys[k]; part_idx[rank] = values[k]; } + } + } + + template + __global__ __launch_bounds__(NUM_THREADS) + void Ablation_LocalNoWorkspace_Kernel( + const __nv_bfloat16* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ dense_kv_indices, + int32_t* __restrict__ scratch, + const int reserved_bos, + const int reserved_eos) + { + using KeyT = uint32_t; using ValueT = int32_t; + using BlockSortT = cub::BlockRadixSort; + __shared__ typename BlockSortT::TempStorage sort_smem; + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + if (row_len <= 0) return; + const int group_begin = (row_len * n) / SPLITS; + const int group_end = (row_len * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + const __nv_bfloat16* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + KeyT keys[ITEMS_PER_THREAD]; ValueT values[ITEMS_PER_THREAD]; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int local_rank = tx + k * NUM_THREADS; + if (local_rank < group_len) { + const int pos = group_begin + local_rank; + const float raw = vortex_to_float_p(row_scores[pos]); + keys [k] = abl_to_uint32(raw); + values[k] = row_idxmap[pos]; + } else { keys[k] = 0u; values[k] = -1; } + } + BlockSortT(sort_smem).SortDescending(keys, values); + if (tx == 0) scratch[blockIdx.x * gridDim.y + blockIdx.y] = values[0]; + } + + template + __global__ __launch_bounds__(32) + void Ablation_WorkspaceWriteOnly_Kernel( + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices) + { + constexpr int LK = kLocalK_Top30; + const int b = blockIdx.x; + const int n = blockIdx.y; + const int lane = threadIdx.x; + const int64_t part_off = (static_cast(b) * SPLITS + n) * LK; + if (lane < LK) { + partial_keys [part_off + lane] = static_cast(b * 31 + n * 7 + lane); + partial_indices[part_off + lane] = b * 1009 + n * 17 + lane; + } + } + + template + __global__ __launch_bounds__(32) + void Ablation_AtomicOnly_Kernel( + int32_t* __restrict__ done_counter, + int32_t* __restrict__ scratch) + { + const int b = blockIdx.x; + const int tx = threadIdx.x; + __shared__ int s_is_last; + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + } + __syncthreads(); + if (s_is_last && tx == 0) scratch[b] = 1; + } + + // Correctness notes for Ablation_MergeOnly_Kernel: + // - MERGE_PROD_DEFAULT exactly mirrors topk_sglang_merge.cu Stage 2 tie + // preference: lower list index wins on equal key (k-way and pairwise); + // lane 0 favors list 0 on equal key (2-way manual). + // - CUB variants sort by uint32 key only. Tie-breaking for duplicate keys + // is implementation-defined and will NOT match production index order. + // Use unique keys for exact index comparison in correctness tests. + // - For throughput benchmarking, duplicate keys are acceptable since only + // latency is measured. + template + __global__ __launch_bounds__(64) + void Ablation_MergeOnly_Kernel( + const uint32_t* __restrict__ partial_keys, + const int32_t* __restrict__ partial_indices, + const int* __restrict__ sparse_kv_indptr, + int* __restrict__ sparse_kv_indices, + const int topk_val, + const int reserved_bos) + { + constexpr int LK = kLocalK_Top30; + constexpr int kCandidates = SPLITS * LK; + const int b = blockIdx.x; + const int tx = threadIdx.x; + const int64_t row_off = static_cast(b) * kCandidates; + const uint32_t* keys_in = partial_keys + row_off; + const int32_t* idx_in = partial_indices + row_off; + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + + if constexpr (MERGE_VARIANT == MERGE_PROD_DEFAULT) { + // Mirrors topk_sglang_merge.cu Stage 2 exactly: different strategy per SPLITS. + if constexpr (SPLITS == 2) { + if (tx < 32) abl_merge_2way_manual_lane0( + keys_in, idx_in, LK, + keys_in + LK, idx_in + LK, LK, + out_idx, topk_val); + } else if constexpr (SPLITS == 4) { + __shared__ uint32_t s_pd01k[LK]; __shared__ int32_t s_pd01i[LK]; + __shared__ uint32_t s_pd23k[LK]; __shared__ int32_t s_pd23i[LK]; + if (tx < 32) abl_merge_pairwise_4(keys_in, idx_in, out_idx, topk_val, + s_pd01k, s_pd01i, s_pd23k, s_pd23i); + } else { + if (tx < 32) abl_merge_sorted_kway( + keys_in, idx_in, out_idx, topk_val); + } + } else if constexpr (MERGE_VARIANT == MERGE_CUB_WARP) { + // kMergeIPT grows with SPLITS: SPLITS=16 → kMergeIPT=16, SPLITS=32 → 32. + // Large IPT increases register pressure and may cause spilling on sm_90+. + constexpr int kMergeIPT = (kCandidates + 31) / 32; + using WarpMergeT = cub::WarpMergeSort; + __shared__ typename WarpMergeT::TempStorage warp_merge; + if (tx < 32) { + uint32_t wkeys[kMergeIPT]; int32_t wvalues[kMergeIPT]; + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + wkeys [k] = (rank < kCandidates) ? keys_in[rank] : 0u; + wvalues[k] = (rank < kCandidates) ? idx_in [rank] : -1; + } + WarpMergeT(warp_merge).Sort(wkeys, wvalues, AblGreaterUint32{}); + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + if (rank < topk_val) out_idx[rank] = wvalues[k]; + } + } + } else if constexpr (MERGE_VARIANT == MERGE_CUB_BLOCK) { + constexpr int kBlockThreads = 64; + constexpr int kMergeIPT = (kCandidates + kBlockThreads - 1) / kBlockThreads; + using BlockMergeT = cub::BlockMergeSort; + __shared__ typename BlockMergeT::TempStorage block_merge; + if (tx < kBlockThreads) { + uint32_t wkeys[kMergeIPT]; int32_t wvalues[kMergeIPT]; + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + wkeys [k] = (rank < kCandidates) ? keys_in[rank] : 0u; + wvalues[k] = (rank < kCandidates) ? idx_in [rank] : -1; + } + BlockMergeT(block_merge).Sort(wkeys, wvalues, AblGreaterUint32{}); + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + if (rank < topk_val) out_idx[rank] = wvalues[k]; + } + } + } else if constexpr (MERGE_VARIANT == MERGE_MANUAL_2WAY) { + static_assert(SPLITS == 2, "manual_2way merge requires SPLITS=2"); + if (tx < 32) abl_merge_2way_manual_lane0( + keys_in, idx_in, LK, + keys_in + LK, idx_in + LK, LK, + out_idx, topk_val); + } else if constexpr (MERGE_VARIANT == MERGE_PAIRWISE_4) { + static_assert(SPLITS == 4, "pairwise_tree merge requires SPLITS=4"); + __shared__ uint32_t s_t01k[LK]; + __shared__ int32_t s_t01i[LK]; + __shared__ uint32_t s_t23k[LK]; + __shared__ int32_t s_t23i[LK]; + if (tx < 32) abl_merge_pairwise_4( + keys_in, idx_in, out_idx, topk_val, + s_t01k, s_t01i, s_t23k, s_t23i); + } else if constexpr (MERGE_VARIANT == MERGE_KWAY) { + // Force k-way for all SPLITS — explicit baseline to isolate k-way cost. + if (tx < 32) abl_merge_sorted_kway( + keys_in, idx_in, out_idx, topk_val); + } + } + + } // namespace + + void topk_output_adaptive_workspace_ablation( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& partial_keys, + at::Tensor& partial_indices, + at::Tensor& done_counter, + at::Tensor& scratch, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t ablation_mode, + const int64_t forced_splits) + { + TORCH_CHECK(x.scalar_type() == at::ScalarType::BFloat16, + "ablation kernels are bf16-only"); + TORCH_CHECK(topk_val > 0 && topk_val <= kMaxFinalK_Top30, + "ablation kernels are K<=32 only"); + + int split = forced_splits > 0 ? static_cast(forced_splits) : 8; + TORCH_CHECK(split == 1 || split == 2 || split == 4 || split == 8 || + split == 16 || split == 32, + "forced_splits must be {1,2,4,8,16,32}, got ", split); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + + // memset_only: just clear the counter. + if (ablation_mode == 8) { + if (split > 1) { + ::cudaMemsetAsync(done_counter.data_ptr(), 0, + sizeof(int32_t) * static_cast(eff_batch_size), + stream); + } + return; + } + // atomic_only needs the counter pre-cleared so each call sees a fresh state. + if (ablation_mode == 4 && split > 1) { + ::cudaMemsetAsync(done_counter.data_ptr(), 0, + sizeof(int32_t) * static_cast(eff_batch_size), + stream); + } + + uint32_t* part_keys_ptr = reinterpret_cast(partial_keys.data_ptr()); + int32_t* part_idx_ptr = partial_indices.data_ptr(); + int32_t* done_ptr = done_counter.data_ptr(); + int32_t* scratch_ptr = scratch.data_ptr(); + + dim3 grid_full (static_cast(eff_batch_size), + static_cast(split)); + dim3 grid_merge(static_cast(eff_batch_size), 1u); + + const __nv_bfloat16* x_ptr = + reinterpret_cast<__nv_bfloat16*>(x.data_ptr()); + + #define LAUNCH_ABL(KERNEL, GRID, NT, ...) \ + do { KERNEL<<>>(__VA_ARGS__); } while (0) + + switch (ablation_mode) { + case 1: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_LocalOnly_Kernel<1, kAblCfg1.num_threads, kAblCfg1.items_per_thread>), grid_full, kAblCfg1.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 2: LAUNCH_ABL((Ablation_LocalOnly_Kernel<2, kAblCfg2.num_threads, kAblCfg2.items_per_thread>), grid_full, kAblCfg2.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 4: LAUNCH_ABL((Ablation_LocalOnly_Kernel<4, kAblCfg4.num_threads, kAblCfg4.items_per_thread>), grid_full, kAblCfg4.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 8: LAUNCH_ABL((Ablation_LocalOnly_Kernel<8, kAblCfg8.num_threads, kAblCfg8.items_per_thread>), grid_full, kAblCfg8.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 16: LAUNCH_ABL((Ablation_LocalOnly_Kernel<16, kAblCfg16.num_threads, kAblCfg16.items_per_thread>), grid_full, kAblCfg16.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 32: LAUNCH_ABL((Ablation_LocalOnly_Kernel<32, kAblCfg32.num_threads, kAblCfg32.items_per_thread>), grid_full, kAblCfg32.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), + part_keys_ptr, part_idx_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + } + break; + } + case 2: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<1, kAblCfg1.num_threads, kAblCfg1.items_per_thread>), grid_full, kAblCfg1.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 2: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<2, kAblCfg2.num_threads, kAblCfg2.items_per_thread>), grid_full, kAblCfg2.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 4: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<4, kAblCfg4.num_threads, kAblCfg4.items_per_thread>), grid_full, kAblCfg4.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 8: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<8, kAblCfg8.num_threads, kAblCfg8.items_per_thread>), grid_full, kAblCfg8.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 16: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<16, kAblCfg16.num_threads, kAblCfg16.items_per_thread>), grid_full, kAblCfg16.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + case 32: LAUNCH_ABL((Ablation_LocalNoWorkspace_Kernel<32, kAblCfg32.num_threads, kAblCfg32.items_per_thread>), grid_full, kAblCfg32.num_threads, + x_ptr, dense_kv_indptr.data_ptr(), + dense_kv_indices.data_ptr(), scratch_ptr, + static_cast(reserved_bos), static_cast(reserved_eos)); break; + } + break; + } + case 3: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<1>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 2: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<2>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 4: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<4>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 8: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<8>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 16: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<16>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + case 32: LAUNCH_ABL((Ablation_WorkspaceWriteOnly_Kernel<32>), grid_full, 32, part_keys_ptr, part_idx_ptr); break; + } + break; + } + case 4: { + switch (split) { + case 1: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<1>), grid_full, 32, done_ptr, scratch_ptr); break; + case 2: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<2>), grid_full, 32, done_ptr, scratch_ptr); break; + case 4: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<4>), grid_full, 32, done_ptr, scratch_ptr); break; + case 8: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<8>), grid_full, 32, done_ptr, scratch_ptr); break; + case 16: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<16>), grid_full, 32, done_ptr, scratch_ptr); break; + case 32: LAUNCH_ABL((Ablation_AtomicOnly_Kernel<32>), grid_full, 32, done_ptr, scratch_ptr); break; + } + break; + } + // kAblMode_MergeProdDefault=5, kAblMode_MergeCubWarp=6, kAblMode_MergeCubBlock=7 + // map to MERGE_PROD_DEFAULT=0, MERGE_CUB_WARP=1, MERGE_CUB_BLOCK=2 respectively. + case 5: case 6: case 7: { + const int variant = static_cast(ablation_mode - 5); + auto launch_merge = [&](auto split_const_var, int v) { + constexpr int S = decltype(split_const_var)::value; + switch (v) { + case MERGE_PROD_DEFAULT: + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); break; + case MERGE_CUB_WARP: + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); break; + case MERGE_CUB_BLOCK: + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 64, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); break; + } + }; + switch (split) { + case 1: launch_merge(std::integral_constant{}, variant); break; + case 2: launch_merge(std::integral_constant{}, variant); break; + case 4: launch_merge(std::integral_constant{}, variant); break; + case 8: launch_merge(std::integral_constant{}, variant); break; + case 16: launch_merge(std::integral_constant{}, variant); break; + case 32: launch_merge(std::integral_constant{}, variant); break; + } + break; + } + case 9: { // kAblMode_MergeManual2Way + TORCH_CHECK(split == 2, "ablation 9 (merge_manual_2way) requires forced_splits=2"); + LAUNCH_ABL((Ablation_MergeOnly_Kernel<2, MERGE_MANUAL_2WAY>), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + break; + } + case 10: { // kAblMode_MergePairwise4 + TORCH_CHECK(split == 4, "ablation 10 (merge_pairwise_4) requires forced_splits=4"); + LAUNCH_ABL((Ablation_MergeOnly_Kernel<4, MERGE_PAIRWISE_4>), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + break; + } + case 11: { // kAblMode_MergeKwayAll: force k-way regardless of SPLITS + auto launch_kway = [&](auto split_tag) { + constexpr int S = decltype(split_tag)::value; + LAUNCH_ABL((Ablation_MergeOnly_Kernel), grid_merge, 32, + part_keys_ptr, part_idx_ptr, + sparse_kv_indptr.data_ptr(), + sparse_kv_indices.data_ptr(), + static_cast(topk_val), + static_cast(reserved_bos)); + }; + switch (split) { + case 1: launch_kway(std::integral_constant{}); break; + case 2: launch_kway(std::integral_constant{}); break; + case 4: launch_kway(std::integral_constant{}); break; + case 8: launch_kway(std::integral_constant{}); break; + case 16: launch_kway(std::integral_constant{}); break; + case 32: launch_kway(std::integral_constant{}); break; + } + break; + } + case 0: { + // full_parallel — re-enter the production workspace API with forced + // split + CONTIGUOUS partition. This makes the "0" mode useful as the + // 100% reference for the other ablations. + topk_output_adaptive_workspace( + x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, + sparse_kv_indices, partial_keys, partial_indices, done_counter, + eff_batch_size, topk_val, reserved_bos, reserved_eos, + max_num_pages, /*mapping_mode=*/0, /*mapping_power=*/0.0, + forced_splits, /*forced_partition=*/kPartContiguous, /*local_mode=*/0); + break; + } + default: + TORCH_CHECK(false, "unknown ablation_mode=", ablation_mode, + " (valid range: 0–11)"); + } + #undef LAUNCH_ABL + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "ablation launch failed: ", ::cudaGetErrorString(rc)); + } diff --git a/csrc/topk_sglang_merge.cu b/csrc/topk_sglang_merge.cu new file mode 100644 index 00000000..078b6d75 --- /dev/null +++ b/csrc/topk_sglang_merge.cu @@ -0,0 +1,1939 @@ +/** + * Vortex adaptive split TopK — random-split parallel K=30 path + fused + * fallback. Lives in topk_sglang_merge.cu (NOT a new file). + * + * Dispatch summary (host-side topk_output_adaptive_workspace): + * + * topk_val >= 1024 → immediate call to topk_output_sglang_fused. + * No workspace touched, no done_counter memset, no + * split kernel launched. Required for 32k → 2048. + * + * topk_val > 32 (and < 1024) → also forwards to fused (no specialised path). + * + * topk_val <= 32: + * forced_splits > 0 → use that split count (1, 2, 4, 8, 16, 32). + * forced_splits <= 0 → use heuristic pick_split_top30(). + * split == 1 (heuristic only) → fall back to fused. + * split == 1 (forced) → run the SPLITS=1 single-CTA path + * for benchmarking (one CUDA block sorts + * the whole row with cub::BlockRadixSort). + * + * Random split semantics: each split processes ONLY its slice of the row, + * not the whole row filtered by predicate, so total work = O(n) not O(n*S). + * + * group_begin = (n * split_id) / SPLITS + * group_end = (n * (split_id+1)) / SPLITS + * For each logical rank r in [group_begin, group_end), the physical + * page-table position is `permute(r, b_offset, n)`. For pow2 n we use + * the affine bijection + * pos = (r * a + b_offset) & (n - 1) + * with a = golden-ratio constant 2654435769 (odd → bijective mod 2^k). + * Per-row b_offset = b * 1013904223 + r0, where r0 is a fixed seed + * for reproducibility. For non-pow2 n we fall back to the contiguous + * mapping pos = r (the chunks then become consecutive slices). + * + * Local stage: cub::BlockRadixSort::SortDescending. Writes the top kLocalK=32 (key, idx) pairs to + * partial workspace per (row, split). + * + * Merge stage (last CTA, SPLITS > 1): cub::WarpMergeSort over SPLITS*32 + * candidates. kMergeIPT = SPLITS items per thread; the 32 warp lanes + * together hold all SPLITS*32 candidates. Each thread's SPLITS items are + * a contiguous descending-sorted slice of the workspace (one split per + * kMergeIPT items), so the WarpMergeSort precondition is satisfied. + * After the sort, threads write their items to out_idx at the correct + * global rank; only ranks < topk_val are written. + * Final top-topk_val global page IDs land in + * sparse_kv_indices[sparse_kv_indptr[b] + reserved_bos + rank]. + * + * done_counter is the external workspace; the host clears it with + * cudaMemsetAsync before each parallel-path launch. The fused-fallback + * branches do not touch it. + */ + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include + #include + #include + + #include + #include + #include + #include + + #include "register.h" + + namespace { + + constexpr int kLocalK_Top30 = 32; // local top-K per chunk + constexpr int kMaxFinalK_Top30 = 32; // accept topk_val up to this + constexpr int64_t kFusedFallbackTopK = 1024; // K >= this routes to fused + + // ============================================================================= + // Local-stage policy for the K=30 split kernel. + // + // BLOCK_FULL_SORT : per-CTA cub::BlockRadixSort over the whole split group, + // capped by the NT*IPT capacity ladder in kCfg* below. + // Original baseline kernel. + // + // SELECT32_SORT32 : per-CTA sglang-style 8-bit radix-select that emits + // exactly LOCAL_K=32 candidates without sorting the + // whole group, followed by a 32-element warp bitonic + // sort (cub::WarpMergeSort with IPT=1). Inner loops + // are strided over the group, so there is no NT*IPT + // ceiling and arbitrary chunk_len is supported. + // + // Both modes share the merge stage: each CTA writes a sorted local top-32 + // to partial workspace, the last CTA per row runs merge_cub_warp_topk. + // ============================================================================= + enum TopK30LocalMode : int { + LOCAL_BLOCK_FULL_SORT = 0, + LOCAL_SELECT32_SORT32 = 1, + }; + + // Affine permutation constants (LCG-style). a is odd → bijective mod 2^k. + constexpr uint32_t kPermuteA = 2654435769u; // golden ratio fractional bits + constexpr uint32_t kPermuteSeedB = 1013904223u; + constexpr uint32_t kPermuteOffset = 0x9E3779B9u; // additional offset + + // ---- bit-level helpers ------------------------------------------------------ + + // Sortable uint32 key for an fp32 value: ascending uint32 == ascending fp32. + __device__ __forceinline__ uint32_t convert_to_uint32(float x) { + uint32_t bits = __float_as_uint(x); + return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); + } + + template + __device__ __forceinline__ float vortex_to_float(T x); + template <> + __device__ __forceinline__ float vortex_to_float(float x) { return x; } + template <> + __device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { + return __bfloat162float(x); + } + + // Stage-1 8-bit bin used by topk_mapping.cuh's compute_stage1_bin. Defined + // here so the header pulls in cleanly, even though we don't otherwise use it. + __device__ __forceinline__ uint8_t convert_to_uint8(float x) { + __half h = __float2half_rn(x); + uint16_t bits = __half_as_ushort(h); + uint16_t key = (bits & 0x8000) ? static_cast(~bits) + : static_cast(bits | 0x8000); + return static_cast(key >> 8); + } + + #include "topk_mapping.cuh" + + // Affine permutation modulo 2^k, bijective when n is a power of two. + __device__ __forceinline__ int permute_pow2(uint32_t r, uint32_t b_off, uint32_t n_mask) { + return static_cast((r * kPermuteA + b_off) & n_mask); + } + + // True iff n is a strictly positive power of two. + __device__ __host__ __forceinline__ bool is_pow2(int n) { + return n > 0 && ((n & (n - 1)) == 0); + } + + // ============================================================================= + // Partition modes — control how a row's logical-rank space [0, n) is mapped + // to physical positions per split CTA. Goal: keep total work O(n) (no + // per-split full scan) while controlling memory access locality. + // + // AFFINE_RANDOM : pos = (a*r + b_off) & (n-1) [random gather] + // CONTIGUOUS : pos = group_begin + local_rank [coalesced] + // STRIDED : pos = split_id + local_rank * SPLITS [interleaved] + // TILE_RANDOM_128 : tile-permute then read TILE=128 contiguous positions + // within each tile. + // TILE_RANDOM_256 : same as 128 but with TILE=256. + // ============================================================================= + enum PartitionMode : int { + PART_AFFINE_RANDOM = 0, + PART_CONTIGUOUS = 1, + PART_STRIDED = 2, + PART_TILE_RANDOM_128 = 3, + PART_TILE_RANDOM_256 = 4, + }; + constexpr int kTileSize128 = 128; + constexpr int kTileSize256 = 256; + + // Tile-random: divide row into TILE-sized contiguous tiles, permute the tile + // IDs across the row using the affine bijection, and assign tiles_per_split + // = chunk_len / TILE tiles to each split. Within a tile, reads are + // contiguous → coalesced 128B / 256B sectors. + template + __device__ __forceinline__ int tile_random_pos( + int local_rank, int row_len, int split_id, + uint32_t b_off, uint32_t n_mask) + { + const int chunk_len = row_len / SPLITS; + if (chunk_len < TILE) { + // Fallback to affine when tiles don't fit. + const int group_begin = (row_len * split_id) / SPLITS; + const int r = group_begin + local_rank; + return permute_pow2(static_cast(r), b_off, n_mask); + } + const int tiles_per_split = chunk_len / TILE; + const int tile_in_split = local_rank / TILE; + const int offset_in_tile = local_rank & (TILE - 1); + const int global_tile_rank = split_id * tiles_per_split + tile_in_split; + const int tile_count = row_len / TILE; + const uint32_t tile_mask = static_cast(tile_count - 1); + const uint32_t tile_id = + (static_cast(global_tile_rank) * kPermuteA + b_off) & tile_mask; + return static_cast(tile_id) * TILE + offset_in_tile; + } + + template + __device__ __forceinline__ int compute_pos( + int local_rank, int row_len, int split_id, + uint32_t b_off, uint32_t n_mask) + { + if constexpr (PARTITION == PART_CONTIGUOUS) { + const int group_begin = (row_len * split_id) / SPLITS; + return group_begin + local_rank; + } else if constexpr (PARTITION == PART_STRIDED) { + // Each split owns lanes [split_id, split_id+SPLITS, split_id+2*SPLITS, ...]. + // Across all splits, the union covers every position in [0, row_len) + // exactly once when row_len is divisible by SPLITS. + return split_id + local_rank * SPLITS; + } else if constexpr (PARTITION == PART_TILE_RANDOM_128) { + return tile_random_pos( + local_rank, row_len, split_id, b_off, n_mask); + } else if constexpr (PARTITION == PART_TILE_RANDOM_256) { + return tile_random_pos( + local_rank, row_len, split_id, b_off, n_mask); + } else { + // AFFINE_RANDOM (default). + const int group_begin = (row_len * split_id) / SPLITS; + const int r = group_begin + local_rank; + return permute_pow2(static_cast(r), b_off, n_mask); + } + } + + // ============================================================================= + // Descending comparator for cub::WarpMergeSort (sorts largest key first). + // ============================================================================= + struct DescendingUint32 { + __device__ __forceinline__ bool operator()(uint32_t a, uint32_t b) const { + return a > b; + } + }; + + // ============================================================================= + // Single-warp CUB merge of SPLITS sorted top-LOCAL_K lists. + // + // Workspace layout (keys_in / idx_in): SPLITS * LOCAL_K elements, split-major + // with each split's LOCAL_K entries sorted descending. With LOCAL_K=32 and + // kMergeIPT = SPLITS (= SPLITS*32 / 32), thread tx holds exactly SPLITS + // consecutive items starting at tx*SPLITS — always a contiguous sorted slice + // within a single split's list. cub::WarpMergeSort precondition is satisfied. + // + // After the sort the global top-final_k indices are written to out_idx[0..final_k-1] + // by the threads that own those ranks; no lane conflicts. + // + // Register pressure: kMergeIPT = SPLITS. For SPLITS=32 each lane holds 32 + // key+value pairs (~128 B registers). Acceptable for sm_90+. + // ============================================================================= + template + __device__ __forceinline__ void merge_cub_warp_topk( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, + int final_k) + { + constexpr int kCandidates = SPLITS * LOCAL_K; + constexpr int kMergeIPT = (kCandidates + 31) / 32; + using WarpMergeT = cub::WarpMergeSort; + __shared__ typename WarpMergeT::TempStorage warp_merge_smem; + const int tx = threadIdx.x; + if (tx < 32) { + uint32_t wkeys[kMergeIPT]; + int32_t wvals[kMergeIPT]; + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + wkeys[k] = (rank < kCandidates) ? keys_in[rank] : 0u; + wvals[k] = (rank < kCandidates) ? idx_in [rank] : -1; + } + WarpMergeT(warp_merge_smem).Sort(wkeys, wvals, DescendingUint32{}); + #pragma unroll + for (int k = 0; k < kMergeIPT; ++k) { + const int rank = tx * kMergeIPT + k; + if (rank < final_k && wvals[k] >= 0) out_idx[rank] = wvals[k]; + } + } + } + + template + inline void setup_kernel_smem_once() { + [[maybe_unused]] + static const auto result = [] { + return ::cudaFuncSetAttribute( + f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); + }(); + TORCH_CHECK(result == cudaSuccess, + "topk_output_adaptive setup failed: ", + ::cudaGetErrorString(result)); + } + + // ============================================================================= + // K=30 random-split parallel kernel. + // + // Grid: (eff_batch_size, SPLITS). + // blockIdx.x = effective row id (0..eff_batch_size-1) + // blockIdx.y = split id (0..SPLITS-1) + // + // Stage 1 (every CTA): + // - Compute group_begin/group_end for this split. + // - For each local rank in [0, group_len), compute physical pos via + // permute_pow2 (or contiguous fallback for non-pow2 n). + // - Apply apply_transform_tmpl, build (uint32_key, int32_global_idx). + // - cub::BlockRadixSort.SortDescending. Top items at start of array. + // + // For SPLITS == 1 the kernel writes the top topk_val directly to + // sparse_kv_indices and returns — no merge. + // + // For SPLITS > 1, write the top kLocalK=32 (key, idx) pairs to the + // partial workspace at offset (b*SPLITS + n)*kLocalK. + // + // Last-CTA-wins barrier (SPLITS > 1): + // __threadfence (release) → atomicAdd → if old == SPLITS-1, last CTA → + // __threadfence (acquire) → __syncthreads. + // + // Stage 2 (last CTA, SPLITS > 1): + // - Load SPLITS*32 candidates into one warp / one block. + // - Sort descending by uint32 key. + // - Lanes 0..topk_val-1 (or threads) write their item to sparse_kv_indices. + // ============================================================================= + template + __global__ __launch_bounds__(NUM_THREADS) + void TopK30_RandomSplit_Parallel_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + int32_t* __restrict__ done_counter, + const int topk_val, + const int reserved_bos, + const int reserved_eos, + const float mapping_power) + { + using KeyT = uint32_t; + using ValueT = int32_t; + using BlockSortT = cub::BlockRadixSort; + + constexpr int kLocalK = kLocalK_Top30; + + __shared__ typename BlockSortT::TempStorage sort_smem; + __shared__ int s_is_last; + + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + + if (row_len <= 0) return; + + // --- Group boundaries (no overlap, no gaps across splits). --- + const int group_begin = (static_cast(row_len) * n) / SPLITS; + const int group_end = (static_cast(row_len) * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + + // --- Permutation parameters. --- + // For pow2 row_len, use affine bijection mod row_len. For non-pow2, fall + // back to identity (chunks become consecutive slices). + const bool row_is_pow2 = is_pow2(row_len); + const uint32_t n_mask = row_is_pow2 ? static_cast(row_len - 1) : 0u; + const uint32_t b_off = + static_cast(b) * kPermuteSeedB + kPermuteOffset; + + const ScoreT* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + + // ------------------------------------------------------------------ Stage 1 + KeyT keys[ITEMS_PER_THREAD]; + ValueT values[ITEMS_PER_THREAD]; + + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int local_rank = tx + k * NUM_THREADS; + if (local_rank < group_len) { + int pos; + if (row_is_pow2) { + pos = compute_pos( + local_rank, row_len, n, b_off, n_mask); + } else { + // Non-pow2 fallback: contiguous slice (also semantically valid for + // partition=CONTIGUOUS). + pos = group_begin + local_rank; + } + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + keys [k] = convert_to_uint32(remapped); + values[k] = row_idxmap[pos]; + } else { + keys [k] = 0u; + values[k] = -1; + } + } + + BlockSortT(sort_smem).SortDescending(keys, values); + __syncthreads(); + + // SPLITS == 1 special case: write final output directly. No atomic, no merge. + if constexpr (SPLITS == 1) { + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int rank = tx * ITEMS_PER_THREAD + k; + if (rank < topk_val) out_idx[rank] = values[k]; + } + return; + } + + // SPLITS > 1: write local top kLocalK to partial workspace. + const int64_t part_off = (static_cast(b) * SPLITS + n) * kLocalK; + uint32_t* part_keys = partial_keys + part_off; + int32_t* part_idx = partial_indices + part_off; + + #pragma unroll + for (int k = 0; k < ITEMS_PER_THREAD; ++k) { + const int rank = tx * ITEMS_PER_THREAD + k; + if (rank < kLocalK) { + part_keys[rank] = keys[k]; + part_idx [rank] = values[k]; + } + } + + // -------------------------------------------------------- Last-CTA barrier + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + // Self-reset: the last CTA clears its slot for the next launch. + // Eliminates the need for cudaMemsetAsync(done_counter) on the host — + // saves ~1-2 µs of CPU launch overhead per call. Same-stream kernels are + // sequenced, so the next launch sees done_counter[b] == 0. + if (s_is_last) done_counter[b] = 0; + } + __syncthreads(); + if (s_is_last == 0) return; + // Acquire fence: ensure the merging CTA observes other CTAs' partial writes. + __threadfence(); + __syncthreads(); + + // ------------------------------------------------------------------ Stage 2 + // cub::WarpMergeSort over all SPLITS*kLocalK candidates (warp 0 only). + // kMergeIPT = SPLITS items per lane; each lane's items are a contiguous + // sorted slice within a single split's list, satisfying WarpMergeSort's + // pre-sorted-per-thread precondition. + const int64_t row_off = static_cast(b) * SPLITS * kLocalK; + const uint32_t* keys_in = partial_keys + row_off; + const int32_t* idx_in = partial_indices + row_off; + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + + merge_cub_warp_topk( + keys_in, idx_in, out_idx, topk_val); + } + + // ============================================================================= + // K=30 SELECT32_SORT32 local-stage kernel (Plan C). + // + // Grid: (eff_batch_size, SPLITS). + // blockIdx.x = effective row id, blockIdx.y = split id. + // + // Per-CTA pipeline: + // Pass 1 - top-byte (bits [31:24]) histogram + suffix-sum-descending, + // find the threshold bin where cumulative count crosses + // LOCAL_K=32 (unique by monotonicity). + // Pass 2 - re-scan the split group: items strictly above the threshold + // bin go straight into the candidate buffer (count is + // guaranteed < LOCAL_K). Items at the threshold bin contribute + // to a sub-bin (bits [23:16]) histogram. + // Pass 3 - find the sub-threshold bin in the sub-hist, then re-scan the + // threshold bin and gather (sub > sub_threshold) and + // (sub == sub_threshold) candidates into the remaining slots. + // Stage D - 32-lane warp bitonic sort over the LOCAL_K candidates, + // descending by uint32 key. Implemented via cub::WarpMergeSort + // with IPT=1 (sort precondition is trivially satisfied). + // + // SPLITS == 1: write top topk_val directly to sparse_kv_indices, no + // workspace, no atomic, no merge. + // SPLITS > 1: write sorted local top-LOCAL_K to partial workspace, the + // last CTA per row runs merge_cub_warp_topk. + // + // No cub::BlockRadixSort smem and no NT*IPT capacity ceiling. Each pass + // is a strided loop over [0, group_len) so the kernel handles any + // chunk length the splits produce, including the SPLITS=1 / 32k case. + // ============================================================================= + template + __global__ __launch_bounds__(NUM_THREADS) + void TopK30_RandomSplit_Select32_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + int32_t* __restrict__ done_counter, + const int topk_val, + const int reserved_bos, + const int reserved_eos, + const float mapping_power) + { + constexpr int LOCAL_K = kLocalK_Top30; + constexpr int kRadix = 256; + + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + __shared__ int s_above_count; // count strictly above threshold_bin (pass 2) + __shared__ int s_thresh_above_count; // count (bin==t && sub>sub_t) (pass 3) + __shared__ int s_thresh_at_count; // count (bin==t && sub==sub_t) (pass 3, capped) + __shared__ int s_threshold_bin; + __shared__ int s_last_remain; + __shared__ int s_sub_threshold_bin; + __shared__ int s_sub_last_remain; + __shared__ int s_strictly_above_sub; + __shared__ uint32_t s_top_keys[LOCAL_K]; + __shared__ int32_t s_top_idx [LOCAL_K]; + __shared__ int s_is_last; + + using LocalSortT = cub::WarpMergeSort; + __shared__ typename LocalSortT::TempStorage local_sort_smem; + + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + + const int group_begin = (static_cast(row_len) * n) / SPLITS; + const int group_end = (static_cast(row_len) * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + + const bool row_is_pow2 = is_pow2(row_len); + const uint32_t n_mask = row_is_pow2 ? static_cast(row_len - 1) : 0u; + const uint32_t b_off = + static_cast(b) * kPermuteSeedB + kPermuteOffset; + const ScoreT* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + + // ---- Init shared state. Strided over the +128 padding so any NT works. ---- + for (int i = tx; i < kRadix + 128; i += NUM_THREADS) { + s_hist_buf[0][i] = 0; + s_hist_buf[1][i] = 0; + } + if (tx == 0) { + s_above_count = 0; + s_thresh_above_count = 0; + s_thresh_at_count = 0; + s_threshold_bin = -1; + s_last_remain = 0; + s_sub_threshold_bin = -1; + s_sub_last_remain = 0; + s_strictly_above_sub = 0; + s_is_last = 0; + } + if (tx < LOCAL_K) { + s_top_keys[tx] = 0u; + s_top_idx [tx] = -1; + } + __syncthreads(); + + // Empty-row early exit. SPLITS>1 must still participate in the merge + // barrier so the last-CTA flag fires; padding is already (0u, -1). + if (row_len <= 0 || group_len <= 0) { + if constexpr (SPLITS > 1) { + const int64_t part_off = + (static_cast(b) * SPLITS + n) * LOCAL_K; + if (tx < LOCAL_K) { + partial_keys [part_off + tx] = 0u; + partial_indices[part_off + tx] = -1; + } + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + if (s_is_last) done_counter[b] = 0; // self-reset for next launch + } + __syncthreads(); + if (s_is_last == 0) return; + __threadfence(); + __syncthreads(); + const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K; + merge_cub_warp_topk( + partial_keys + row_off, partial_indices + row_off, + sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos, + topk_val); + } + return; + } + + // Strided suffix-sum-descending over s_hist_buf, ping-pong; result in [0]. + // Works for any NUM_THREADS (uses a strided inner loop over kRadix). + auto run_cumsum_strided = [&]() { + #pragma unroll + for (int i = 0; i < 8; ++i) { + const int j = 1 << i; + const int k = i & 1; + for (int idx = tx; idx < kRadix; idx += NUM_THREADS) { + int v = s_hist_buf[k][idx]; + if (idx + j < kRadix) v += s_hist_buf[k][idx + j]; + s_hist_buf[k ^ 1][idx] = v; + } + __syncthreads(); + } + }; + + // ============================================================ + // Pass 1: top-byte histogram. + // ============================================================ + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) { + pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + } else { + pos = group_begin + local_rank; + } + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + ::atomicAdd(&s_hist_buf[0][bin], 1); + } + __syncthreads(); + + run_cumsum_strided(); + // s_hist_buf[0][bin] = count of items with key>>24 >= bin. + + const int total_items = s_hist_buf[0][0]; + + // Find threshold bin: the unique bin t where total_at_or_above[t] >= K + // and strictly_above[t] < K. Strided so any NT works. + for (int bin = tx; bin < kRadix; bin += NUM_THREADS) { + const int total_at_or_above = s_hist_buf[0][bin]; + const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0; + if (total_at_or_above >= LOCAL_K && strictly_above < LOCAL_K) { + s_threshold_bin = bin; + s_last_remain = LOCAL_K - strictly_above; + } + } + __syncthreads(); + + if (total_items <= LOCAL_K) { + // Few-elements path: collect everything in arbitrary order, pad rest. + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int slot = ::atomicAdd(&s_above_count, 1); + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } + __syncthreads(); + // s_top_keys/idx already pre-padded to (0u, -1) at init. + } else { + const int threshold_bin = s_threshold_bin; + + // Reset both hist buffers for the sub-bin pass. + for (int i = tx; i < kRadix + 128; i += NUM_THREADS) { + s_hist_buf[0][i] = 0; + s_hist_buf[1][i] = 0; + } + __syncthreads(); + + // ============================================================ + // Pass 2: gather strictly-above-threshold items + build sub-hist. + // ============================================================ + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + if (bin > threshold_bin) { + const int slot = ::atomicAdd(&s_above_count, 1); + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } else if (bin == threshold_bin) { + const int sub_bin = static_cast((key >> 16) & 0xFF); + ::atomicAdd(&s_hist_buf[0][sub_bin], 1); + } + } + __syncthreads(); + + run_cumsum_strided(); + + const int last_remain = s_last_remain; + for (int bin = tx; bin < kRadix; bin += NUM_THREADS) { + const int total_at_or_above = s_hist_buf[0][bin]; + const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0; + if (total_at_or_above >= last_remain && strictly_above < last_remain) { + s_sub_threshold_bin = bin; + s_sub_last_remain = last_remain - strictly_above; + s_strictly_above_sub = strictly_above; + } + } + __syncthreads(); + + const int sub_threshold_bin = s_sub_threshold_bin; + const int sub_last_remain = s_sub_last_remain; + const int strictly_above_sub_bn = s_strictly_above_sub; + const int above_base = s_above_count; // = strictly_above_threshold + + // ============================================================ + // Pass 3: gather threshold-bin sub-above + sub-at items. + // ============================================================ + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + if (bin == threshold_bin) { + const int sub_bin = static_cast((key >> 16) & 0xFF); + if (sub_bin > sub_threshold_bin) { + const int rel = ::atomicAdd(&s_thresh_above_count, 1); + const int slot = above_base + rel; + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } else if (sub_bin == sub_threshold_bin) { + const int rel = ::atomicAdd(&s_thresh_at_count, 1); + if (rel < sub_last_remain) { + const int slot = above_base + strictly_above_sub_bn + rel; + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } + } + } + } + __syncthreads(); + } + + // ============================================================ + // Stage D: 32-lane warp bitonic sort over the LOCAL_K candidates. + // cub::WarpMergeSort with IPT=1 has trivial pre-sorted-per-thread + // precondition (each lane owns exactly 1 item). + // ============================================================ + if (tx < 32) { + uint32_t kk[1] = { s_top_keys[tx] }; + int32_t vv[1] = { s_top_idx [tx] }; + LocalSortT(local_sort_smem).Sort(kk, vv, DescendingUint32{}); + s_top_keys[tx] = kk[0]; + s_top_idx [tx] = vv[0]; + } + __syncthreads(); + + // ============================================================ + // SPLITS == 1: direct write to sparse_kv_indices. + // ============================================================ + if constexpr (SPLITS == 1) { + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + if (tx < topk_val) out_idx[tx] = s_top_idx[tx]; + return; + } + + // ============================================================ + // SPLITS > 1: write workspace, last-CTA-wins barrier, merge. + // ============================================================ + const int64_t part_off = (static_cast(b) * SPLITS + n) * LOCAL_K; + if (tx < LOCAL_K) { + partial_keys [part_off + tx] = s_top_keys[tx]; + partial_indices[part_off + tx] = s_top_idx [tx]; + } + + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + if (s_is_last) done_counter[b] = 0; // self-reset for next launch + } + __syncthreads(); + if (s_is_last == 0) return; + __threadfence(); + __syncthreads(); + + const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K; + merge_cub_warp_topk( + partial_keys + row_off, partial_indices + row_off, + sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos, + topk_val); + } + + // ============================================================================= + // Per-split (NUM_THREADS, ITEMS_PER_THREAD) configuration. + // + // NUM_THREADS * ITEMS_PER_THREAD must cover the per-split chunk length + // (= ceil(max_num_pages / SPLITS)). Picked once per SPLITS rather than per + // (SPLITS, max_num_pages) to keep the template instantiation count + // manageable. + // + // cub::BlockRadixSort uses ~NT*IPT*sizeof(KeyT) bytes of static shared + // memory; ptxas rejects kernels exceeding 48 KB static smem on sm_100a + // without opt-in (which static smem can't easily use). Using uint32 keys + // + 8-byte (key,value) effective footprint, we keep NT*IPT*4 <= ~32 KB → + // NT*IPT <= 8192. Coverage: + // + // | chunk_max | covers max_num_pages + // ------------------------------------ + // 1 | 8192 | 8192 + // 2 | 8192 | 16384 + // 4 | 4096 | 16384 + // 8 | 4096 | 32768 + // 16 | 2048 | 32768 + // 32 | 1024 | 32768 + // + // Configs above the coverage row fall back to the fused single-CTA kernel + // in the dispatcher (capacity check below). + // ============================================================================= + struct SplitCfg { int splits, num_threads, items_per_thread; }; + + constexpr SplitCfg kCfg1 = { 1, 1024, 8 }; // cap 8192 + constexpr SplitCfg kCfg2 = { 2, 1024, 8 }; // cap 8192 + constexpr SplitCfg kCfg4 = { 4, 512, 8 }; // cap 4096 + constexpr SplitCfg kCfg8 = { 8, 256, 16 }; // cap 4096 + constexpr SplitCfg kCfg16 = {16, 128, 16 }; // cap 2048 + constexpr SplitCfg kCfg32 = {32, 64, 16 }; // cap 1024 + + // Returns the per-split capacity (NT*IPT) for a given split count, or 0 if + // the split is not supported. + inline int split_capacity(int split) { + switch (split) { + case 1: return kCfg1.num_threads * kCfg1.items_per_thread; + case 2: return kCfg2.num_threads * kCfg2.items_per_thread; + case 4: return kCfg4.num_threads * kCfg4.items_per_thread; + case 8: return kCfg8.num_threads * kCfg8.items_per_thread; + case 16: return kCfg16.num_threads * kCfg16.items_per_thread; + case 32: return kCfg32.num_threads * kCfg32.items_per_thread; + default: return 0; + } + } + + inline int next_supported_split(int required) { + if (required <= 1) return 1; + if (required <= 2) return 2; + if (required <= 4) return 4; + if (required <= 8) return 8; + if (required <= 16) return 16; + return 32; + } + + // ============================================================================= + // Per-split NUM_THREADS for the SELECT32_SORT32 kernel. + // + // No NT*IPT capacity ladder: the kernel scans the split group with strided + // loops, so any group_len works at any NT. Picked here only to balance + // memory throughput vs occupancy. NT=128 is fine for high splits because + // chunk_len shrinks proportionally (max_pages=32k / SPLITS=32 -> 1024). + // ============================================================================= + struct SelectCfg { int splits, num_threads; }; + constexpr SelectCfg kSelCfg1 = { 1, 1024 }; + constexpr SelectCfg kSelCfg2 = { 2, 1024 }; + constexpr SelectCfg kSelCfg4 = { 4, 512 }; + constexpr SelectCfg kSelCfg8 = { 8, 256 }; + constexpr SelectCfg kSelCfg16 = {16, 128 }; + constexpr SelectCfg kSelCfg32 = {32, 128 }; + + // SM-cover policy: pick the smallest supported split such that + // total_ctas = eff_batch_size * split >= sm_count. This prioritises + // filling the device. Capacity / merge cost are NOT considered here — + // the dispatcher's capacity check below catches infeasible configs. + inline int choose_split_k30_b200(int64_t eff_bs, int64_t /*max_pages*/, + int forced, int sm_count) + { + if (forced > 0) return forced; + constexpr int kSMCoverDefault = 180; // B200 multiprocessorCount + const int target_blocks = sm_count > 0 ? sm_count : kSMCoverDefault; + const int required = static_cast( + (target_blocks + eff_bs - 1) / eff_bs); + return next_supported_split(required); + } + + // Default partition mode picker. The B200 sweep at K=30 shows CONTIGUOUS + // dominates affine and tile-random by 10-15% at high splits (8/16/32) and + // is within noise at low splits — coalesced loads are the bottleneck once + // each split has more than a handful of threads. Random vs contiguous is + // correctness-equivalent here (each split's local top-32 is merged via + // CUB WarpMergeSort into the global top-30, regardless of partition layout). + // Override via forced_partition for ablation. + inline int default_partition(int /*split*/, int64_t /*max_num_pages*/) { + return PART_CONTIGUOUS; + } + + // Heuristic split picker for K<=32. + // + // ALWAYS returns an adaptive split count in {1,2,4,8,16,32}. Never falls + // back to fused — for K=30 the dispatcher in topk_output_adaptive_workspace + // is required to stay on the adaptive path. split=1 means "single-CTA + // adaptive kernel", NOT "use fused sglang baseline". + // + // Table from B200 sweep (benchmarks/bench_topk_setting_sweep.py, + // SELECT32_SORT32 local mode, CONTIGUOUS partition, CUB WarpMergeSort merge): + // + // max_pages <= 32768 : split=1 wins or ties at every B in {1..16}; + // e.g. 4k/B=4 -> 17.2us @s=1 vs 23.4us @s=2. + // max_pages == 65536 : split=4 beats split=1 by ~18-19% within adaptive + // (s=1 41.8us vs s=4 33.7us); 4 CTAs * 16k chunk + // keeps the per-CTA radix select small enough that + // the merge cost is amortised by the parallel scan. + // + // forced_splits overrides this for benchmarking. + inline int pick_split_top30(int64_t /*eff_bs*/, int64_t max_pages) { + if (max_pages > 32768) return 4; + return 1; + } + + // ============================================================================= + // Mid-K (K in {64, 128, 256, 512}) generalized SELECTK_SORTK kernel. + // + // Same structure as TopK30_RandomSplit_Select32_Kernel, with LOCAL_K + // templated up to 512 and the local sort + final merge replaced with + // cub::BlockMergeSort variants sized by LOCAL_K and SPLITS*LOCAL_K + // respectively. + // + // Per-CTA pipeline (mirrors K=30 path; only sizes change): + // Pass 1 — top-byte (bits [31:24]) histogram + suffix-sum-descending, + // find threshold bin where cumulative count crosses LOCAL_K. + // Pass 2 — strictly-above-threshold goes straight to candidate buffer; + // equal-to-threshold contributes to sub-bin (bits [23:16]) histogram. + // Pass 3 — sub-threshold then sub-equal candidates. + // Sort — cub::BlockMergeSort over LOCAL_K candidates with NT_SORT=128 + // and IPT_SORT = ceil(LOCAL_K, 128) / 128 (LOCAL_K=64 padded to 128). + // + // Final merge (last CTA, SPLITS > 1): + // cub::BlockMergeSort over SPLITS * LOCAL_K candidates. Capped at 4096 + // candidates total (NT=256, IPT=16) for register pressure. + // + // Capacity policy (max SPLITS per LOCAL_K, candidates capped at 4096): + // LOCAL_K=64 -> SPLITS in {1, 2, 4, 8, 16, 32} (max C=2048) + // LOCAL_K=128 -> SPLITS in {1, 2, 4, 8, 16, 32} (max C=4096) + // LOCAL_K=256 -> SPLITS in {1, 2, 4, 8, 16} (max C=4096) + // LOCAL_K=512 -> SPLITS in {1, 2, 4, 8} (max C=4096) + // + // For NT_SORT=128 we need LOCAL_K to be a multiple of 128 (the slot + // buffer is padded with (key=0, idx=-1) sentinels otherwise; a few + // threads sort dummy items, but the descending sort drops them past + // LOCAL_K and they are never read). + // ============================================================================= + // SortNTConfig sizes the slot buffer + IPT for the local-stage sort. + // We sort SLOTS_PADDED items with NT threads, IPT_SORT items per thread. + // SLOTS_PADDED = max(LOCAL_K, NT) so all NT threads have at least one slot. + template + struct SortNTConfig { + static constexpr int SLOTS_PADDED = + (LOCAL_K >= NT) ? ((LOCAL_K + NT - 1) / NT) * NT : NT; + static constexpr int NT_SORT = NT; + static constexpr int IPT_SORT = SLOTS_PADDED / NT; + }; + + template + struct MergeNTConfig { + // Final merge runs in the same kernel block (last CTA), so NT_MERGE must + // equal the kernel's NUM_THREADS or BlockMergeSort would deadlock. + static constexpr int PADDED = (CANDIDATES + NT - 1) / NT * NT; + static constexpr int NT_MERGE = NT; + static constexpr int IPT_MERGE = PADDED / NT; + }; + + template + __device__ __forceinline__ void merge_block_sort_topk_midk( + const uint32_t* __restrict__ keys_in, + const int32_t* __restrict__ idx_in, + int32_t* __restrict__ out_idx, + int final_k) + { + constexpr int kCandidates = SPLITS * LOCAL_K; + constexpr int NT = MergeNTConfig::NT_MERGE; + constexpr int IPT = MergeNTConfig::IPT_MERGE; + using BlockSortT = cub::BlockMergeSort; + __shared__ typename BlockSortT::TempStorage block_merge_smem; + + const int tx = threadIdx.x; + uint32_t bkeys[IPT]; + int32_t bvals[IPT]; + #pragma unroll + for (int k = 0; k < IPT; ++k) { + const int rank = tx * IPT + k; + bkeys[k] = (rank < kCandidates) ? keys_in[rank] : 0u; + bvals[k] = (rank < kCandidates) ? idx_in [rank] : -1; + } + BlockSortT(block_merge_smem).Sort(bkeys, bvals, DescendingUint32{}); + #pragma unroll + for (int k = 0; k < IPT; ++k) { + const int rank = tx * IPT + k; + if (rank < final_k && bvals[k] >= 0) out_idx[rank] = bvals[k]; + } + } + + template + __global__ __launch_bounds__(NUM_THREADS) + void TopKMidK_RandomSplit_SelectK_Kernel( + const ScoreT* __restrict__ score, + const int* __restrict__ dense_kv_indptr, + const int* __restrict__ sparse_kv_indptr, + const int* __restrict__ dense_kv_indices, + int* __restrict__ sparse_kv_indices, + uint32_t* __restrict__ partial_keys, + int32_t* __restrict__ partial_indices, + int32_t* __restrict__ done_counter, + const int topk_val, + const int reserved_bos, + const int reserved_eos, + const float mapping_power) + { + constexpr int kRadix = 256; + constexpr int SLOTS_PADDED = SortNTConfig::SLOTS_PADDED; + constexpr int NT_SORT = SortNTConfig::NT_SORT; + constexpr int IPT_SORT = SortNTConfig::IPT_SORT; + // cub::BlockMergeSort calls __syncthreads() internally — every thread in + // the block must enter the sort branch, so NT_SORT must equal NUM_THREADS. + static_assert(NT_SORT == NUM_THREADS, + "NT_SORT must equal NUM_THREADS or BlockMergeSort deadlocks"); + + alignas(128) __shared__ int s_hist_buf[2][kRadix + 128]; + __shared__ int s_above_count; + __shared__ int s_thresh_above_count; + __shared__ int s_thresh_at_count; + __shared__ int s_threshold_bin; + __shared__ int s_last_remain; + __shared__ int s_sub_threshold_bin; + __shared__ int s_sub_last_remain; + __shared__ int s_strictly_above_sub; + __shared__ uint32_t s_top_keys[SLOTS_PADDED]; + __shared__ int32_t s_top_idx [SLOTS_PADDED]; + __shared__ int s_is_last; + + using LocalSortT = cub::BlockMergeSort; + __shared__ typename LocalSortT::TempStorage local_sort_smem; + + const int b = blockIdx.x; + const int n = blockIdx.y; + const int tx = threadIdx.x; + + const int row_start = dense_kv_indptr[b] + reserved_bos; + const int row_end = dense_kv_indptr[b + 1] - reserved_eos; + const int row_len = max(0, row_end - row_start); + + const int group_begin = (static_cast(row_len) * n) / SPLITS; + const int group_end = (static_cast(row_len) * (n + 1)) / SPLITS; + const int group_len = group_end - group_begin; + + const bool row_is_pow2 = is_pow2(row_len); + const uint32_t n_mask = row_is_pow2 ? static_cast(row_len - 1) : 0u; + const uint32_t b_off = + static_cast(b) * kPermuteSeedB + kPermuteOffset; + const ScoreT* row_scores = score + row_start; + const int* row_idxmap = dense_kv_indices + row_start; + + // ---- Init shared state. Strided over padding so any NT works. ---- + for (int i = tx; i < kRadix + 128; i += NUM_THREADS) { + s_hist_buf[0][i] = 0; + s_hist_buf[1][i] = 0; + } + if (tx == 0) { + s_above_count = 0; + s_thresh_above_count = 0; + s_thresh_at_count = 0; + s_threshold_bin = -1; + s_last_remain = 0; + s_sub_threshold_bin = -1; + s_sub_last_remain = 0; + s_strictly_above_sub = 0; + s_is_last = 0; + } + for (int i = tx; i < SLOTS_PADDED; i += NUM_THREADS) { + s_top_keys[i] = 0u; + s_top_idx [i] = -1; + } + __syncthreads(); + + // Empty-row early exit (preserves merge barrier for SPLITS>1). + if (row_len <= 0 || group_len <= 0) { + if constexpr (SPLITS > 1) { + const int64_t part_off = + (static_cast(b) * SPLITS + n) * LOCAL_K; + for (int i = tx; i < LOCAL_K; i += NUM_THREADS) { + partial_keys [part_off + i] = 0u; + partial_indices[part_off + i] = -1; + } + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + if (s_is_last) done_counter[b] = 0; // self-reset for next launch + } + __syncthreads(); + if (s_is_last == 0) return; + __threadfence(); + __syncthreads(); + const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K; + merge_block_sort_topk_midk( + partial_keys + row_off, partial_indices + row_off, + sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos, + topk_val); + } + return; + } + + auto run_cumsum_strided = [&]() { + #pragma unroll + for (int i = 0; i < 8; ++i) { + const int j = 1 << i; + const int k = i & 1; + for (int idx = tx; idx < kRadix; idx += NUM_THREADS) { + int v = s_hist_buf[k][idx]; + if (idx + j < kRadix) v += s_hist_buf[k][idx + j]; + s_hist_buf[k ^ 1][idx] = v; + } + __syncthreads(); + } + }; + + // ============== Pass 1: top-byte histogram. ============== + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) { + pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + } else { + pos = group_begin + local_rank; + } + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + ::atomicAdd(&s_hist_buf[0][bin], 1); + } + __syncthreads(); + + run_cumsum_strided(); + const int total_items = s_hist_buf[0][0]; + + for (int bin = tx; bin < kRadix; bin += NUM_THREADS) { + const int total_at_or_above = s_hist_buf[0][bin]; + const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0; + if (total_at_or_above >= LOCAL_K && strictly_above < LOCAL_K) { + s_threshold_bin = bin; + s_last_remain = LOCAL_K - strictly_above; + } + } + __syncthreads(); + + if (total_items <= LOCAL_K) { + // Few-elements path: collect everything, pad rest. + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int slot = ::atomicAdd(&s_above_count, 1); + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } + __syncthreads(); + } else { + const int threshold_bin = s_threshold_bin; + + for (int i = tx; i < kRadix + 128; i += NUM_THREADS) { + s_hist_buf[0][i] = 0; + s_hist_buf[1][i] = 0; + } + __syncthreads(); + + // ============== Pass 2: gather above + sub-hist. ============== + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + if (bin > threshold_bin) { + const int slot = ::atomicAdd(&s_above_count, 1); + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } else if (bin == threshold_bin) { + const int sub_bin = static_cast((key >> 16) & 0xFF); + ::atomicAdd(&s_hist_buf[0][sub_bin], 1); + } + } + __syncthreads(); + + run_cumsum_strided(); + + const int last_remain = s_last_remain; + for (int bin = tx; bin < kRadix; bin += NUM_THREADS) { + const int total_at_or_above = s_hist_buf[0][bin]; + const int strictly_above = (bin + 1 < kRadix) ? s_hist_buf[0][bin + 1] : 0; + if (total_at_or_above >= last_remain && strictly_above < last_remain) { + s_sub_threshold_bin = bin; + s_sub_last_remain = last_remain - strictly_above; + s_strictly_above_sub = strictly_above; + } + } + __syncthreads(); + + const int sub_threshold_bin = s_sub_threshold_bin; + const int sub_last_remain = s_sub_last_remain; + const int strictly_above_sub_bn = s_strictly_above_sub; + const int above_base = s_above_count; + + // ============== Pass 3: sub-above + sub-at. ============== + for (int local_rank = tx; local_rank < group_len; local_rank += NUM_THREADS) { + int pos; + if (row_is_pow2) pos = compute_pos(local_rank, row_len, n, b_off, n_mask); + else pos = group_begin + local_rank; + const float raw = vortex_to_float(row_scores[pos]); + const float remapped = apply_transform_tmpl(raw, mapping_power); + const uint32_t key = convert_to_uint32(remapped); + const int bin = static_cast(key >> 24); + if (bin == threshold_bin) { + const int sub_bin = static_cast((key >> 16) & 0xFF); + if (sub_bin > sub_threshold_bin) { + const int rel = ::atomicAdd(&s_thresh_above_count, 1); + const int slot = above_base + rel; + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } else if (sub_bin == sub_threshold_bin) { + const int rel = ::atomicAdd(&s_thresh_at_count, 1); + if (rel < sub_last_remain) { + const int slot = above_base + strictly_above_sub_bn + rel; + if (slot < LOCAL_K) { + s_top_keys[slot] = key; + s_top_idx [slot] = row_idxmap[pos]; + } + } + } + } + } + __syncthreads(); + } + + // ============== Sort SLOTS_PADDED candidates with cub::BlockMergeSort. ============== + // The first LOCAL_K slots may have real data; padded slots have (0u, -1). + // Sort uses NT_SORT threads; only those threads load/store sort items. + if (tx < NT_SORT) { + uint32_t kk[IPT_SORT]; + int32_t vv[IPT_SORT]; + #pragma unroll + for (int k = 0; k < IPT_SORT; ++k) { + const int slot = tx * IPT_SORT + k; + kk[k] = s_top_keys[slot]; + vv[k] = s_top_idx [slot]; + } + LocalSortT(local_sort_smem).Sort(kk, vv, DescendingUint32{}); + #pragma unroll + for (int k = 0; k < IPT_SORT; ++k) { + const int slot = tx * IPT_SORT + k; + s_top_keys[slot] = kk[k]; + s_top_idx [slot] = vv[k]; + } + } + __syncthreads(); + + // ============== SPLITS == 1: direct write to sparse_kv_indices. ============== + if constexpr (SPLITS == 1) { + int32_t* out_idx = sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos; + for (int rank = tx; rank < topk_val; rank += NUM_THREADS) { + out_idx[rank] = s_top_idx[rank]; + } + return; + } + + // ============== SPLITS > 1: workspace, last-CTA barrier, merge. ============== + const int64_t part_off = (static_cast(b) * SPLITS + n) * LOCAL_K; + for (int i = tx; i < LOCAL_K; i += NUM_THREADS) { + partial_keys [part_off + i] = s_top_keys[i]; + partial_indices[part_off + i] = s_top_idx [i]; + } + + __threadfence(); + __syncthreads(); + if (tx == 0) { + const int old = ::atomicAdd(&done_counter[b], 1); + s_is_last = (old == SPLITS - 1) ? 1 : 0; + if (s_is_last) done_counter[b] = 0; // self-reset for next launch + } + __syncthreads(); + if (s_is_last == 0) return; + __threadfence(); + __syncthreads(); + + const int64_t row_off = static_cast(b) * SPLITS * LOCAL_K; + merge_block_sort_topk_midk( + partial_keys + row_off, partial_indices + row_off, + sparse_kv_indices + sparse_kv_indptr[b] + reserved_bos, + topk_val); + } + + // Mid-K capacity policy. Returns true iff (LOCAL_K, SPLITS) is supported. + inline bool midk_split_supported(int local_k, int splits) { + const int candidates = local_k * splits; + if (candidates > 4096) return false; + if (splits != 1 && splits != 2 && splits != 4 && splits != 8 && + splits != 16 && splits != 32) return false; + return true; + } + + // Pick LOCAL_K from K. We use the smallest power-of-two LOCAL_K >= K. + inline int midk_local_k_from_topk(int topk_val) { + if (topk_val <= 64) return 64; + if (topk_val <= 128) return 128; + if (topk_val <= 256) return 256; + if (topk_val <= 512) return 512; + return -1; // unsupported + } + + } // namespace + + #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + + // ============================================================================= + // Workspace API: zero hot-path at::empty allocations. + // + // topk_val >= 1024 → forwards to topk_output_sglang_fused without + // touching workspace tensors or done_counter. + // topk_val <= 32 → uses the K=30 random-split parallel path with + // forced_splits (if > 0) or pick_split_top30(). + // else → also forwards to fused (no specialised path here). + // + // partial_keys / partial_indices must each have at least + // eff_batch_size * SPLITS * kLocalK_Top30 = eff_batch_size * SPLITS * 32 + // int32 elements. done_counter must have at least eff_batch_size int32 + // elements; it is cleared with cudaMemsetAsync inside this call before the + // parallel kernel launches (and is NOT touched on the fused-fallback path). + // + // forced_splits encoding: + // <= 0 : use heuristic pick_split_top30(). + // 1 : single-CTA local sort path (for benchmarking). + // 2/4/8/16/32 : forced parallel split. + // anything else : TORCH_CHECK failure. + // ============================================================================= + void topk_output_adaptive_workspace( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& partial_keys, + at::Tensor& partial_indices, + at::Tensor& done_counter, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + const int64_t forced_splits, + const int64_t forced_partition, + const int64_t local_mode) + { + // ============== Fused fallback (no workspace touch) ============== + // K >= 1024: 32k -> 2048 lives here. Direct delegate. NO workspace check, + // NO memset, NO split kernel launch — this is the near-zero-overhead + // hot fast-path required for the K=2048 workload. + // + // K in (32, 1024) also routes here: those Ks have no specialised + // adaptive kernel and the fused baseline is the right path. NOTE: for + // K <= 32 (the K=30 path) we never come back to fused below — every + // adaptive sub-path stays on the split kernel. + if (topk_val >= kFusedFallbackTopK || topk_val > kMaxFinalK_Top30) { + topk_output_sglang_fused( + x, dense_kv_indptr, sparse_kv_indptr, + dense_kv_indices, sparse_kv_indices, + eff_batch_size, topk_val, + reserved_bos, reserved_eos, max_num_pages, + mapping_mode, mapping_power, std::nullopt, std::nullopt); + return; + } + + // ============== K <= 32 adaptive path (no fused fallback below) ============== + + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + TORCH_CHECK(topk_val > 0, "topk_val must be > 0"); + TORCH_CHECK(eff_batch_size >= 1, "eff_batch_size must be >= 1"); + TORCH_CHECK(max_num_pages >= 1, "max_num_pages must be >= 1"); + + // local_mode validation. -1 (or any negative) defaults to SELECT32_SORT32, + // which is the production mode (no NT*IPT capacity ceiling, supports the + // full pages={4096,8192,16384,32768} x splits={1..32} matrix). + int local_mode_int = static_cast(local_mode); + if (local_mode_int < 0) local_mode_int = LOCAL_SELECT32_SORT32; + TORCH_CHECK(local_mode_int == LOCAL_BLOCK_FULL_SORT || + local_mode_int == LOCAL_SELECT32_SORT32, + "local_mode must be 0 (BLOCK_FULL_SORT) or 1 (SELECT32_SORT32), got ", + local_mode_int); + + TORCH_CHECK( + mapping_mode == MAPPING_NONE || + mapping_mode == MAPPING_POWER || + mapping_mode == MAPPING_LOG || + mapping_mode == MAPPING_ASINH || + mapping_mode == MAPPING_LOG1P || + mapping_mode == MAPPING_TRUNC8 || + mapping_mode == MAPPING_ERF || + mapping_mode == MAPPING_TANH || + mapping_mode == MAPPING_SUBTRACT || + mapping_mode == MAPPING_EXP_STRETCH || + mapping_mode == MAPPING_SHIFT_POW2 || + mapping_mode == MAPPING_SHIFT_POW3 || + mapping_mode == MAPPING_LINEAR_STEEP || + mapping_mode == MAPPING_HALF_SQUARE || + mapping_mode == MAPPING_HALF_CUBE, + "topk_output_adaptive_workspace: mapping_mode=", mapping_mode, + " not supported."); + + // Resolve split count. K=30 NEVER falls back to fused: split=1 means + // single-CTA adaptive kernel, not the fused baseline. + int split; + if (forced_splits > 0) { + split = static_cast(forced_splits); + TORCH_CHECK(split == 1 || split == 2 || split == 4 || split == 8 || + split == 16 || split == 32, + "forced_splits must be one of {1,2,4,8,16,32}, got ", split); + } else { + split = pick_split_top30(eff_batch_size, max_num_pages); + } + + // Resolve partition mode. + int partition; + if (forced_partition >= 0) { + partition = static_cast(forced_partition); + TORCH_CHECK(partition == PART_AFFINE_RANDOM || + partition == PART_CONTIGUOUS || + partition == PART_STRIDED || + partition == PART_TILE_RANDOM_128 || + partition == PART_TILE_RANDOM_256, + "forced_partition must be 0=affine,1=contiguous,2=strided," + "3=tile_random_128,4=tile_random_256"); + } else { + partition = default_partition(split, max_num_pages); + } + + // Capacity check applies ONLY to BLOCK_FULL_SORT, which uses + // cub::BlockRadixSort and is bounded by NT*IPT static-smem footprint. + // SELECT32_SORT32 has no such ceiling (its inner loops are strided). + // + // K=30 must NEVER silently fall back to fused — if BLOCK_FULL_SORT can't + // fit the chunk, we fail loudly so the caller picks a finer split or + // switches to SELECT32_SORT32. + if (local_mode_int == LOCAL_BLOCK_FULL_SORT) { + const int chunk_max = static_cast((max_num_pages + split - 1) / split); + const int cap = split_capacity(split); + TORCH_CHECK(cap >= chunk_max, + "topk_output_adaptive_workspace: BLOCK_FULL_SORT split=", split, + " has NT*IPT=", cap, + " < required chunk_max=", chunk_max, + " (max_num_pages=", max_num_pages, + "). Use SELECT32_SORT32 (local_mode=1) or a finer split."); + } + + // From here we enter the parallel path. The split=1 forced case still + // reads partial_keys/partial_indices/done_counter args but does NOT + // touch them — we accept any tensor of the right dtype. + CHECK_CUDA(partial_keys); + CHECK_CUDA(partial_indices); + CHECK_CUDA(done_counter); + TORCH_CHECK(partial_keys.dtype() == at::kInt, + "partial_keys must be int32 (uint32 reinterpreted)"); + TORCH_CHECK(partial_indices.dtype() == at::kInt, "partial_indices must be int32"); + TORCH_CHECK(done_counter.dtype() == at::kInt, "done_counter must be int32"); + + if (split > 1) { + TORCH_CHECK(done_counter.numel() >= eff_batch_size, + "done_counter[", done_counter.numel(), + "] too small for eff_batch_size=", eff_batch_size); + const int64_t need = eff_batch_size * static_cast(split) * kLocalK_Top30; + TORCH_CHECK(partial_keys.numel() >= need, + "partial_keys too small: ", partial_keys.numel(), " < ", need); + TORCH_CHECK(partial_indices.numel() >= need, + "partial_indices too small: ", partial_indices.numel(), " < ", need); + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + // No cudaMemsetAsync(done_counter) — the kernel self-resets done_counter[b] + // = 0 from the last CTA's tx==0 thread, so subsequent launches see it + // already zero. Saves ~1-2 µs of CPU launch overhead per call. Caller + // contract: done_counter must be zero-initialized once at workspace + // allocation (at::zeros) and not touched by anyone else on this stream. + + uint32_t* part_keys_ptr = + reinterpret_cast(partial_keys.data_ptr()); + int32_t* part_idx_ptr = partial_indices.data_ptr(); + int32_t* done_ptr = done_counter.data_ptr(); + + dim3 grid(static_cast(eff_batch_size), + static_cast(split)); + + // ---- BLOCK_FULL_SORT macro chain (TopK30_RandomSplit_Parallel_Kernel) ---- + #define LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART) \ + do { \ + auto* fn = &TopK30_RandomSplit_Parallel_Kernel< \ + DTYPE, MODE_VAL, SPLITS_VAL, NT, IPT, PART>; \ + fn<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + part_keys_ptr, part_idx_ptr, done_ptr, \ + static_cast(topk_val), \ + static_cast(reserved_bos), \ + static_cast(reserved_eos), \ + mp); \ + } while (0) + + #define DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT) \ + do { \ + switch (partition) { \ + case PART_AFFINE_RANDOM: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_AFFINE_RANDOM); break; \ + case PART_CONTIGUOUS: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_CONTIGUOUS); break; \ + case PART_STRIDED: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_STRIDED); break; \ + case PART_TILE_RANDOM_128: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_TILE_RANDOM_128); break; \ + case PART_TILE_RANDOM_256: \ + LAUNCH_TOP30_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, IPT, PART_TILE_RANDOM_256); break; \ + default: TORCH_CHECK(false, "unreachable partition mode"); \ + } \ + } while (0) + + #define DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + switch (split) { \ + case 1: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 1, kCfg1.num_threads, kCfg1.items_per_thread); break; \ + case 2: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 2, kCfg2.num_threads, kCfg2.items_per_thread); break; \ + case 4: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 4, kCfg4.num_threads, kCfg4.items_per_thread); break; \ + case 8: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 8, kCfg8.num_threads, kCfg8.items_per_thread); break; \ + case 16: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 16, kCfg16.num_threads, kCfg16.items_per_thread); break; \ + case 32: DISPATCH_PART_BLOCK(DTYPE, PTR_EXPR, MODE_VAL, 32, kCfg32.num_threads, kCfg32.items_per_thread); break; \ + default: TORCH_CHECK(false, "unsupported split=", split); \ + } \ + } while (0) + + // ---- SELECT32_SORT32 macro chain (TopK30_RandomSplit_Select32_Kernel) ---- + #define LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART) \ + do { \ + auto* fn = &TopK30_RandomSplit_Select32_Kernel< \ + DTYPE, MODE_VAL, SPLITS_VAL, NT, PART>; \ + fn<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + part_keys_ptr, part_idx_ptr, done_ptr, \ + static_cast(topk_val), \ + static_cast(reserved_bos), \ + static_cast(reserved_eos), \ + mp); \ + } while (0) + + #define DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT) \ + do { \ + switch (partition) { \ + case PART_AFFINE_RANDOM: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_AFFINE_RANDOM); break; \ + case PART_CONTIGUOUS: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_CONTIGUOUS); break; \ + case PART_STRIDED: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_STRIDED); break; \ + case PART_TILE_RANDOM_128: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_TILE_RANDOM_128); break; \ + case PART_TILE_RANDOM_256: \ + LAUNCH_TOP30_SELECT(DTYPE, PTR_EXPR, MODE_VAL, SPLITS_VAL, NT, PART_TILE_RANDOM_256); break; \ + default: TORCH_CHECK(false, "unreachable partition mode"); \ + } \ + } while (0) + + #define DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + switch (split) { \ + case 1: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 1, kSelCfg1.num_threads); break; \ + case 2: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 2, kSelCfg2.num_threads); break; \ + case 4: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 4, kSelCfg4.num_threads); break; \ + case 8: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 8, kSelCfg8.num_threads); break; \ + case 16: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 16, kSelCfg16.num_threads); break; \ + case 32: DISPATCH_PART_SELECT(DTYPE, PTR_EXPR, MODE_VAL, 32, kSelCfg32.num_threads); break; \ + default: TORCH_CHECK(false, "unsupported split=", split); \ + } \ + } while (0) + + // Top-level: choose the local-mode chain, then mapping_mode → split → partition. + // MAPPING_TRUNC8 shares its semantics with MAPPING_NONE (identity transform). + // Routing both to MAPPING_NONE saves one template instantiation per chain. + #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ + do { \ + if (local_mode_int == LOCAL_SELECT32_SORT32) { \ + switch (mapping_mode) { \ + case MAPPING_NONE: \ + case MAPPING_TRUNC8: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: DISPATCH_SPLIT_SELECT(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + default: TORCH_CHECK(false, "unreachable mapping_mode"); \ + } \ + } else { \ + switch (mapping_mode) { \ + case MAPPING_NONE: \ + case MAPPING_TRUNC8: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_NONE); break; \ + case MAPPING_POWER: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ + case MAPPING_LOG: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_LOG); break; \ + case MAPPING_ASINH: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ + case MAPPING_LOG1P: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ + case MAPPING_ERF: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ + case MAPPING_TANH: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ + case MAPPING_SUBTRACT: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ + case MAPPING_EXP_STRETCH: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ + case MAPPING_SHIFT_POW2: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ + case MAPPING_SHIFT_POW3: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ + case MAPPING_LINEAR_STEEP: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ + case MAPPING_HALF_SQUARE: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_HALF_SQUARE); break; \ + case MAPPING_HALF_CUBE: DISPATCH_SPLIT_BLOCK(DTYPE, PTR_EXPR, MAPPING_HALF_CUBE); break; \ + default: TORCH_CHECK(false, "unreachable mapping_mode"); \ + } \ + } \ + } while (0) + + if (x.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MODE(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); + } else if (x.scalar_type() == at::ScalarType::Float) { + DISPATCH_MODE(float, x.data_ptr()); + } else { + TORCH_CHECK(false, "topk_output_adaptive_workspace: unsupported dtype ", + x.scalar_type()); + } + + #undef DISPATCH_MODE + #undef DISPATCH_SPLIT_SELECT + #undef DISPATCH_PART_SELECT + #undef LAUNCH_TOP30_SELECT + #undef DISPATCH_SPLIT_BLOCK + #undef DISPATCH_PART_BLOCK + #undef LAUNCH_TOP30_BLOCK + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "topk_output_adaptive_workspace launch failed: ", + ::cudaGetErrorString(rc)); + } + + // ============================================================================= + // Legacy entry point — allocates workspace internally and forwards. + // + // NOTE: this path performs at::empty allocations and is therefore NOT a + // reference for latency benchmarks. New callers should use + // topk_output_adaptive_workspace with preallocated workspace. + // ============================================================================= + void topk_output_adaptive( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power) + { + // Workspace big enough for the largest split this kernel may pick (32). + constexpr int64_t kMaxSplit = 32; + const int64_t ws_elems = eff_batch_size * kMaxSplit * kLocalK_Top30; + + auto opts_i32 = at::TensorOptions().device(x.device()).dtype(at::kInt); + at::Tensor partial_keys = at::empty({ws_elems}, opts_i32); + at::Tensor partial_indices = at::empty({ws_elems}, opts_i32); + at::Tensor done_counter = at::empty({eff_batch_size}, opts_i32); + + topk_output_adaptive_workspace( + x, dense_kv_indptr, sparse_kv_indptr, dense_kv_indices, + sparse_kv_indices, partial_keys, partial_indices, done_counter, + eff_batch_size, topk_val, reserved_bos, reserved_eos, + max_num_pages, mapping_mode, mapping_power, + /*forced_splits=*/-1, + /*forced_partition=*/-1, + /*local_mode=*/LOCAL_SELECT32_SORT32); + } + + +// ============================================================================= +// Mid-K (K in {64, 128, 256, 512}) adaptive split entry point. +// +// Separate from topk_output_adaptive_workspace so the K=30 production path +// stays untouched. workspace tensors must be sized for +// eff_batch_size * SPLITS * LOCAL_K +// where LOCAL_K is the smallest power of two >= topk_val (max 512), and +// SPLITS is forced_splits if > 0, else 1. +// +// Dispatch contract: +// topk_val < 64 or > 512 → TORCH_CHECK failure (use the K=30 path or fused). +// forced_splits encoding: +// <= 0 : default policy (currently split=1; sweep will inform a heuristic). +// 1/2/4/8/16/32 : forced split, must satisfy midk_split_supported(). +// ============================================================================= +void topk_output_adaptive_workspace_midk( + const at::Tensor& x, + const at::Tensor& dense_kv_indptr, + const at::Tensor& sparse_kv_indptr, + const at::Tensor& dense_kv_indices, + at::Tensor& sparse_kv_indices, + at::Tensor& partial_keys, + at::Tensor& partial_indices, + at::Tensor& done_counter, + const int64_t eff_batch_size, + const int64_t topk_val, + const int64_t reserved_bos, + const int64_t reserved_eos, + const int64_t max_num_pages, + const int64_t mapping_mode, + const double mapping_power, + const int64_t forced_splits) +{ + CHECK_CUDA(x); + CHECK_CUDA(dense_kv_indptr); + CHECK_CUDA(sparse_kv_indptr); + CHECK_CUDA(dense_kv_indices); + CHECK_CUDA(sparse_kv_indices); + + TORCH_CHECK(topk_val >= 64 && topk_val <= 512, + "topk_output_adaptive_workspace_midk: topk_val=", topk_val, + " out of range [64, 512]. Use topk_output_adaptive_workspace " + "for K<=32 or topk_output_sglang_fused for K>512."); + TORCH_CHECK(eff_batch_size >= 1, "eff_batch_size must be >= 1"); + TORCH_CHECK(max_num_pages >= 1, "max_num_pages must be >= 1"); + + const int local_k = midk_local_k_from_topk(static_cast(topk_val)); + TORCH_CHECK(local_k > 0, "unreachable: midk_local_k_from_topk failed for K=", topk_val); + + // Mid-K mappings: NONE / TRUNC8 only for now (kept template count low). + // POWER/LOG/etc. are easy to add later once we measure their value. + TORCH_CHECK(mapping_mode == MAPPING_NONE || mapping_mode == MAPPING_TRUNC8, + "topk_output_adaptive_workspace_midk: mapping_mode=", + mapping_mode, " not yet supported (use NONE or TRUNC8)."); + + int split; + if (forced_splits > 0) { + split = static_cast(forced_splits); + TORCH_CHECK(midk_split_supported(local_k, split), + "topk_output_adaptive_workspace_midk: split=", split, + " not supported for LOCAL_K=", local_k, + " (would need ", split * local_k, " merge candidates, max 4096)."); + } else { + // Sweep-driven default. From bench_results/midk_best_adaptive_p50.csv: + // + // pages <= 65536 : adaptive loses every cell vs fused on p50 — but a + // user calling this entry point explicitly is asking + // for adaptive anyway, so use split=1 (smallest gap). + // pages > 65536 : fused unsupported (smem ceiling). Best splits: + // K=64 → 16 + // K=128 → 16 + // K=256 → 2 + // K=512 → 4 + // + // forced_splits > 0 still overrides this, e.g. for benchmarking. + if (max_num_pages > 65536) { + switch (local_k) { + case 64: split = 16; break; + case 128: split = 16; break; + case 256: split = 2; break; + case 512: split = 4; break; + default: split = 1; + } + } else { + split = 1; + } + } + + CHECK_CUDA(partial_keys); + CHECK_CUDA(partial_indices); + CHECK_CUDA(done_counter); + TORCH_CHECK(partial_keys.dtype() == at::kInt, "partial_keys must be int32"); + TORCH_CHECK(partial_indices.dtype() == at::kInt, "partial_indices must be int32"); + TORCH_CHECK(done_counter.dtype() == at::kInt, "done_counter must be int32"); + + if (split > 1) { + TORCH_CHECK(done_counter.numel() >= eff_batch_size, + "done_counter[", done_counter.numel(), + "] too small for eff_batch_size=", eff_batch_size); + const int64_t need = eff_batch_size * static_cast(split) * local_k; + TORCH_CHECK(partial_keys.numel() >= need, + "partial_keys too small: ", partial_keys.numel(), " < ", need); + TORCH_CHECK(partial_indices.numel() >= need, + "partial_indices too small: ", partial_indices.numel(), " < ", need); + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + const float mp = static_cast(mapping_power); + + // No cudaMemsetAsync — kernel self-resets done_counter (see midk kernel + // and dispatcher comment for topk_output_adaptive_workspace). + + uint32_t* part_keys_ptr = + reinterpret_cast(partial_keys.data_ptr()); + int32_t* part_idx_ptr = partial_indices.data_ptr(); + int32_t* done_ptr = done_counter.data_ptr(); + + dim3 grid(static_cast(eff_batch_size), + static_cast(split)); + + // NT scales inversely with SPLITS so the per-CTA scan loop has roughly + // the same iteration count regardless of split count. Mirror the K=30 + // kSelCfg ladder. With NT=128 at SPLITS=1, a 65k-page row would force + // 512 iters/thread/pass — way slower than the ~64 iters fused achieves + // with NT=1024 single-CTA. Match fused throughput at split=1. + // + // SPLITS=1 : NT=1024 (chunk = full row) + // SPLITS=2 : NT=512 + // SPLITS=4 : NT=256 + // SPLITS=8 : NT=128 + // SPLITS=16 : NT=128 + // SPLITS=32 : NT=128 + + #define LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, SPLITS_VAL, NT_VAL) \ + do { \ + auto* fn = &TopKMidK_RandomSplit_SelectK_Kernel< \ + DTYPE, MODE_VAL, LOCAL_K_VAL, SPLITS_VAL, NT_VAL, PART_CONTIGUOUS>; \ + fn<<>>( \ + PTR_EXPR, \ + dense_kv_indptr.data_ptr(), \ + sparse_kv_indptr.data_ptr(), \ + dense_kv_indices.data_ptr(), \ + sparse_kv_indices.data_ptr(), \ + part_keys_ptr, part_idx_ptr, done_ptr, \ + static_cast(topk_val), \ + static_cast(reserved_bos), \ + static_cast(reserved_eos), \ + mp); \ + } while (0) + + #define DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL) \ + do { \ + switch (split) { \ + case 1: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 1, 1024); break; \ + case 2: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 2, 512); break; \ + case 4: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 4, 256); break; \ + case 8: LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 8, 128); break; \ + case 16: \ + if constexpr ((LOCAL_K_VAL) * 16 <= 4096) \ + LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 16, 128); \ + else \ + TORCH_CHECK(false, "midk: split=16 unsupported for LOCAL_K=", LOCAL_K_VAL);\ + break; \ + case 32: \ + if constexpr ((LOCAL_K_VAL) * 32 <= 4096) \ + LAUNCH_MIDK(DTYPE, PTR_EXPR, MODE_VAL, LOCAL_K_VAL, 32, 128); \ + else \ + TORCH_CHECK(false, "midk: split=32 unsupported for LOCAL_K=", LOCAL_K_VAL);\ + break; \ + default: TORCH_CHECK(false, "midk: unsupported split=", split); \ + } \ + } while (0) + + #define DISPATCH_LK_MIDK(DTYPE, PTR_EXPR, MODE_VAL) \ + do { \ + switch (local_k) { \ + case 64: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 64); break; \ + case 128: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 128); break; \ + case 256: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 256); break; \ + case 512: DISPATCH_SPLIT_MIDK(DTYPE, PTR_EXPR, MODE_VAL, 512); break; \ + default: TORCH_CHECK(false, "midk: unreachable LOCAL_K=", local_k); \ + } \ + } while (0) + + #define DISPATCH_MIDK(DTYPE, PTR_EXPR) \ + do { \ + /* MAPPING_TRUNC8 aliases MAPPING_NONE; same template instantiation. */ \ + DISPATCH_LK_MIDK(DTYPE, PTR_EXPR, MAPPING_NONE); \ + } while (0) + + if (x.scalar_type() == at::ScalarType::BFloat16) { + DISPATCH_MIDK(__nv_bfloat16, + reinterpret_cast<__nv_bfloat16*>(x.data_ptr())); + } else if (x.scalar_type() == at::ScalarType::Float) { + DISPATCH_MIDK(float, x.data_ptr()); + } else { + TORCH_CHECK(false, "topk_output_adaptive_workspace_midk: unsupported dtype ", + x.scalar_type()); + } + + #undef DISPATCH_MIDK + #undef DISPATCH_LK_MIDK + #undef DISPATCH_SPLIT_MIDK + #undef LAUNCH_MIDK + + const auto rc = cudaGetLastError(); + TORCH_CHECK(rc == cudaSuccess, + "topk_output_adaptive_workspace_midk launch failed: ", + ::cudaGetErrorString(rc)); +} diff --git a/csrc/topk_sglang_parallel.cu b/csrc/topk_sglang_parallel.cu deleted file mode 100644 index e728940d..00000000 --- a/csrc/topk_sglang_parallel.cu +++ /dev/null @@ -1,639 +0,0 @@ -/** - * Vortex TopK — single-kernel parallel+merge pipeline. - * - * ONE kernel launch. Per-chunk selection and cross-chunk merge both run - * inside the same grid-(N, Batch) launch. The last-arriving CTA for - * each batch (detected by a program-lifetime __device__ done-counter + - * atomicInc wrap-around) carries out the merge — no second launch, no - * per-call cudaMemset for barrier state. - * - * Correctness: - * Stage 1 per-chunk uses ONE 8-bit radix histogram + ONE 8-bit - * refinement round on the threshold bin (16 bits of selection - * precision). For bf16 input (8 mantissa bits effective), this is - * lossless — two items with the same 16-bit key are bit-identical as - * bf16 values. - * - * Stage 2 merge operates on N*K pre-remapped keys in shared memory - * and uses the same 8-bit-hist + 8-bit-refine pattern, which is - * strictly sufficient to pick the correct top-K from the union. - * - * Low-overhead primitives: - * - Warp-level ballot+popc compaction on the "bin > threshold" path - * so each warp issues ONE atomicAdd on the block counter instead - * of one per thread. - * - Program-lifetime __device__ done-counter sized for realistic - * batch×head counts; atomicInc wraps back to 0 at num_chunks so - * there's no memset on the hot path. - * - Vectorised float4/int4 loads from global → smem in the merge. - * - * Supported mapping modes (IDs from csrc/topk_mapping.cuh): - * 3=POWER, 6=ASINH, 7=LOG1P, 9=ERF, 10=TANH, 11=SUBTRACT, - * 13=EXP_STRETCH, 15=SHIFT_POW2, 16=SHIFT_POW3, 17=LINEAR_STEEP. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include - -#include "register.h" - -namespace { - -// ---- Launch constants ------------------------------------------------------ - -constexpr int kThreadsPerBlock = 1024; -constexpr int kWarpSize = 32; -constexpr int RADIX = 256; -constexpr size_t kMaxDynSmem = 96 * 1024; -constexpr int VORTEX_MAX_TOPK = 2048; - -// Stage-2 holds N*K (key, idx) pairs in smem = 8 B/item. -constexpr int kMergeCap = 8192; - -// Max batch the single kernel can sequence. Sized for realistic -// bs×heads (decode). __device__ globals are zero-initialised at -// program start; atomicInc wrap-around keeps each entry at 0 between -// launches, so no host-side memset on the hot path. -constexpr int kMaxBatch = 8192; -__device__ unsigned int g_done_counter[kMaxBatch]; - -// ---- Device helpers -------------------------------------------------------- - -__device__ __forceinline__ uint32_t convert_to_uint32(float x) { - uint32_t bits = __float_as_uint(x); - return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u); -} - -// Required symbol for topk_mapping.cuh's compute_stage1_bin. Not used -// directly by the kernel body here, but the header includes a forward -// declaration that resolves against this definition at link time. -__device__ __forceinline__ uint8_t convert_to_uint8(float x) { - __half h = __float2half_rn(x); - uint16_t bits = __half_as_ushort(h); - uint16_t key = (bits & 0x8000) ? static_cast(~bits) - : static_cast(bits | 0x8000); - return static_cast(key >> 8); -} - -template -__device__ __forceinline__ float vortex_to_float(T x); -template <> -__device__ __forceinline__ float vortex_to_float(float x) { return x; } -template <> -__device__ __forceinline__ float vortex_to_float<__nv_bfloat16>(__nv_bfloat16 x) { - return __bfloat162float(x); -} - -#include "topk_mapping.cuh" - -// ============================================================================ -// 8-step suffix cumsum over 256 bins. After the call s_hist[0][i] is -// the count of items with bin >= i (monotone non-increasing). -// ============================================================================ -__device__ __forceinline__ void run_cumsum_256(int s_hist[2][RADIX + 128]) { - const int tx = threadIdx.x; -#pragma unroll 8 - for (int i = 0; i < 8; ++i) { - static_assert(1 << 8 == RADIX); - if (C10_LIKELY(tx < RADIX)) { - const int j = 1 << i; - const int k = i & 1; - int value = s_hist[k][tx]; - if (tx < RADIX - j) value += s_hist[k][tx + j]; - s_hist[k ^ 1][tx] = value; - } - __syncthreads(); - } -} - -// ============================================================================ -// Warp-level ballot+popc compaction. -// -// Every participating thread offers a boolean `selected`. Exactly ONE -// atomicAdd per warp — issued by the first active lane — reserves -// `warp_count` slots; other selected lanes derive their slot via a -// popc prefix sum. Safe when called from inside a divergent region -// (uses __activemask(), not a fixed all-ones mask). -// ============================================================================ -__device__ __forceinline__ int warp_compact_slot(bool selected, int* s_counter) { - const uint32_t mask = __activemask(); - const uint32_t ballot = __ballot_sync(mask, selected); - const int lane = threadIdx.x & (kWarpSize - 1); - const int warp_count = __popc(ballot); - const int rank_in_warp = __popc(ballot & ((1u << lane) - 1u)); - - const int first_lane = __ffs(mask) - 1; - int base = 0; - if (lane == first_lane) { - base = (warp_count > 0) ? ::atomicAdd(s_counter, warp_count) : 0; - } - base = __shfl_sync(mask, base, first_lane); - return selected ? (base + rank_in_warp) : -1; -} - -// ============================================================================ -// Combined kernel — Stage 1 (per-chunk) + barrier + Stage 2 (merge). -// -// Grid = (Batch, N). One CTA per (batch, chunk). -// Block = kThreadsPerBlock = 1024. -// -// Shared-memory layout (reused across phases): -// Phase 1 needs: -// s_remapped[chunk_size] (float) — cached apply_transform output. -// s_bins[chunk_size] (uint8) — cached coarse bin. -// Merge needs: -// s_scores[N*K] (float) — pair buffer, loaded vectorised. -// s_indices[N*K] (int32) — pair buffer. -// kSmemBytes is sized to host max of both. -// -// Sync between phases: -// After Phase 1's workspace writes, __threadfence() publishes them, -// then thread 0 does `atomicInc(&g_done_counter[bx], N-1)` which -// cycles 0→1→…→N-1→0 so no reset is needed between calls. The CTA -// whose returned `old == N-1` is the last one — it falls through -// into the merge; other CTAs return. -// ============================================================================ -template -__global__ __launch_bounds__(kThreadsPerBlock) -void TopK_Parallel_Kernel( - const ScoreT* __restrict__ score, // [Batch, N, chunk_size] - int32_t* __restrict__ global_idx, // [Batch, K] - float* __restrict__ partial_keys, // [Batch, N, K] workspace - int32_t* __restrict__ partial_idx, // [Batch, N, K] workspace - int N, - int chunk_size, - int K, - float mapping_power) -{ - const int b = blockIdx.x; - const int n = blockIdx.y; - const int tx = threadIdx.x; - - // Addresses for this CTA's chunk slice and its slot in the workspace. - const ScoreT* chunk_in = score + (static_cast(b) * N + n) * chunk_size; - float* chunk_keys_out = partial_keys + (static_cast(b) * N + n) * K; - int32_t* chunk_idx_out = partial_idx + (static_cast(b) * N + n) * K; - const int32_t idx_base = n * chunk_size; // batch-local offset - - // ---------------------------------------------------------------- smem - extern __shared__ char smem_raw[]; - - // Shared-memory counters / histogram live in static smem so the - // Phase-1 and merge phases can share the same dynamic pool. - alignas(128) __shared__ int s_hist_buf[2][RADIX + 128]; - alignas(128) __shared__ int s_counter; - alignas(128) __shared__ int s_threshold_bin; - alignas(128) __shared__ int s_sub_threshold_bin; - alignas(128) __shared__ int s_last_remain; - alignas(128) __shared__ int s_is_last; - auto& s_hist = s_hist_buf[0]; - - // ========================================================================= - // Phase 1: per-chunk TopK via 8-bit radix + 8-bit refinement. - // ========================================================================= - // - // Dynamic smem region used as: - // s_remapped : chunk_size * 4 B (cached apply_transform output) - // s_bins : chunk_size * 1 B (cached Stage-1 bin) - // - // Refinement is a second 8-bit bucket on bits [23:16] of the - // sign-flipped u32 key, used to refine the threshold bin. 8 + 8 = - // 16 bits of selection precision → lossless for bf16. - float* s_remapped = reinterpret_cast(smem_raw); - uint8_t* s_bins = reinterpret_cast(s_remapped + chunk_size); - - // ---- Degenerate chunk_size <= K : emit everything as-is. ------------- - if (chunk_size <= K) { - for (int i = tx; i < K; i += blockDim.x) { - if (i < chunk_size) { - const float raw = vortex_to_float(chunk_in[i]); - chunk_keys_out[i] = apply_transform_tmpl(raw, mapping_power); - chunk_idx_out [i] = i + idx_base; - } else { - chunk_keys_out[i] = -CUDART_INF_F; - chunk_idx_out [i] = -1; - } - } - } else { - // ---- Histogram pass 1: transform + bucket; cache both to smem. ---- - if (tx < RADIX + 1) s_hist[tx] = 0; - if (tx == 0) { s_counter = 0; s_threshold_bin = -1; s_last_remain = 0; } - __syncthreads(); - - for (int idx = tx; idx < chunk_size; idx += blockDim.x) { - const float raw = vortex_to_float(chunk_in[idx]); - const float remapped = apply_transform_tmpl(raw, mapping_power); - const uint32_t b32 = convert_to_uint32(remapped); - const int bin = (b32 >> 24) & 0xFF; - s_remapped[idx] = remapped; - s_bins [idx] = static_cast(bin); - ::atomicAdd(&s_hist[bin], 1); - } - __syncthreads(); - - run_cumsum_256(s_hist_buf); - - if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { - s_threshold_bin = tx; - s_last_remain = K - s_hist[tx + 1]; - } - __syncthreads(); - const int threshold_bin = s_threshold_bin; - - // ---- Emit bin > threshold (warp-popc) and build refinement hist. ---- - if (tx < RADIX + 1) s_hist[tx] = 0; - __syncthreads(); - - const int num_iters = (chunk_size + blockDim.x - 1) / blockDim.x; - for (int it = 0; it < num_iters; ++it) { - const int idx = it * blockDim.x + tx; - const bool in_range = (idx < chunk_size); - int bin = -1; - if (in_range) bin = static_cast(s_bins[idx]); - const bool take_above = in_range && (bin > threshold_bin); - - const int slot = warp_compact_slot(take_above, &s_counter); - if (take_above) { - chunk_keys_out[slot] = s_remapped[idx]; - chunk_idx_out [slot] = idx + idx_base; - } else if (in_range && bin == threshold_bin) { - const uint32_t b32 = convert_to_uint32(s_remapped[idx]); - const int sub_bin = (b32 >> 16) & 0xFF; - ::atomicAdd(&s_hist[sub_bin], 1); - } - } - __syncthreads(); - - // ---- Refinement cumsum → sub-threshold bin. ------------------------ - run_cumsum_256(s_hist_buf); - if (tx < RADIX && s_hist[tx] > s_last_remain - && s_hist[tx + 1] <= s_last_remain) { - s_sub_threshold_bin = tx; - // budget for items at the sub-threshold bin - s_last_remain = s_last_remain - s_hist[tx + 1]; - } - if (tx == 0 && s_sub_threshold_bin == -1) { - // Only possible if last_remain == 0 (bin > threshold already emitted - // exactly K items). Nothing more to do; make the sub bin a sentinel. - s_sub_threshold_bin = RADIX; // no sub-threshold bin - } - __syncthreads(); - const int sub_threshold_bin = s_sub_threshold_bin; - - // ---- Emit threshold-bin items using sub-threshold logic. ---------- - for (int it = 0; it < num_iters; ++it) { - const int idx = it * blockDim.x + tx; - const bool in_range = (idx < chunk_size); - int bin = -1; - if (in_range) bin = static_cast(s_bins[idx]); - int sub_bin = -1; - if (in_range && bin == threshold_bin) { - const uint32_t b32 = convert_to_uint32(s_remapped[idx]); - sub_bin = (b32 >> 16) & 0xFF; - } - - const bool take_sub_above = (sub_bin > sub_threshold_bin); - const int slot = warp_compact_slot(take_sub_above, &s_counter); - if (take_sub_above) { - chunk_keys_out[slot] = s_remapped[idx]; - chunk_idx_out [slot] = idx + idx_base; - } else if (sub_bin == sub_threshold_bin) { - const int pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) { - chunk_keys_out[K - pos] = s_remapped[idx]; - chunk_idx_out [K - pos] = idx + idx_base; - } - } - } - __syncthreads(); - } - - // ========================================================================= - // Barrier: publish this CTA's workspace writes and atomicInc the - // per-batch done-counter. The CTA that sees old == N-1 is the last - // arriving one; every other CTA returns here. - // ========================================================================= - __threadfence(); - __syncthreads(); - if (tx == 0) { - const unsigned int old = ::atomicInc( - &g_done_counter[b], static_cast(N - 1)); - s_is_last = (old == static_cast(N - 1)) ? 1 : 0; - } - __syncthreads(); - if (s_is_last == 0) return; - - // ========================================================================= - // Phase 2 (merge, only in last-arriving CTA): - // load N*K candidates into smem (vectorised) → - // 8-bit histogram in smem → - // threshold → warp-popc emit above + tie-bin refinement. - // ========================================================================= - const int total = N * K; - const float* keys_in = partial_keys + static_cast(b) * total; - const int32_t* idx_in = partial_idx + static_cast(b) * total; - int32_t* out_idx = global_idx + static_cast(b) * K; - - // Reuse the same dynamic smem region as Phase 1 — Phase 1's caches - // are dead now. Layout: [ s_scores : total floats | s_indices : total int32 ]. - float* s_scores = reinterpret_cast(smem_raw); - int32_t* s_indices = reinterpret_cast(s_scores + total); - - // Vectorised 128-bit loads when `total` is a multiple of 4. - if ((total & 3) == 0) { - const float4* keys_v = reinterpret_cast(keys_in); - const int4* idx_v = reinterpret_cast (idx_in); - float4* ss_v = reinterpret_cast (s_scores); - int4* si_v = reinterpret_cast (s_indices); - const int total4 = total >> 2; - for (int i = tx; i < total4; i += blockDim.x) { - ss_v[i] = keys_v[i]; - si_v[i] = idx_v [i]; - } - } else { - for (int i = tx; i < total; i += blockDim.x) { - s_scores [i] = keys_in[i]; - s_indices[i] = idx_in [i]; - } - } - - if (tx < RADIX + 1) s_hist[tx] = 0; - if (tx == 0) { - s_counter = 0; - s_threshold_bin = -1; - s_sub_threshold_bin = -1; - s_last_remain = 0; - } - __syncthreads(); - - // (2) 8-bit histogram in smem. - const int num_iters_m = (total + blockDim.x - 1) / blockDim.x; - for (int it = 0; it < num_iters_m; ++it) { - const int i = it * blockDim.x + tx; - if (i < total && s_indices[i] >= 0) { - const uint32_t b32 = convert_to_uint32(s_scores[i]); - const int bin = (b32 >> 24) & 0xFF; - ::atomicAdd(&s_hist[bin], 1); - } - } - __syncthreads(); - - run_cumsum_256(s_hist_buf); - - // Fast path: no threshold search needed when valid_count ≤ K. - const int valid_count = s_hist[0]; - if (valid_count <= K) { - for (int it = 0; it < num_iters_m; ++it) { - const int i = it * blockDim.x + tx; - const bool take = (i < total) && (s_indices[i] >= 0); - const int slot = warp_compact_slot(take, &s_counter); - if (take) out_idx[slot] = s_indices[i]; - } - return; - } - - if (tx < RADIX && s_hist[tx] > K && s_hist[tx + 1] <= K) { - s_threshold_bin = tx; - s_last_remain = K - s_hist[tx + 1]; - } - __syncthreads(); - const int threshold_bin_m = s_threshold_bin; - - // (3) Emit above threshold via warp-popc; build sub-bin histogram on - // bits [23:16] for the tie-bin refinement. - if (tx < RADIX + 1) s_hist[tx] = 0; - __syncthreads(); - - for (int it = 0; it < num_iters_m; ++it) { - const int i = it * blockDim.x + tx; - bool in_valid = false; - int bin = -1; - uint32_t b32 = 0; - if (i < total) { - const int32_t idx = s_indices[i]; - if (idx >= 0) { - in_valid = true; - b32 = convert_to_uint32(s_scores[i]); - bin = (b32 >> 24) & 0xFF; - } - } - const bool take_above = in_valid && (bin > threshold_bin_m); - const int slot = warp_compact_slot(take_above, &s_counter); - if (take_above) { - out_idx[slot] = s_indices[i]; - } else if (in_valid && bin == threshold_bin_m) { - const int sub_bin = (b32 >> 16) & 0xFF; - ::atomicAdd(&s_hist[sub_bin], 1); - } - } - __syncthreads(); - - // (4) Refinement cumsum → sub-threshold bin. - run_cumsum_256(s_hist_buf); - if (tx < RADIX && s_hist[tx] > s_last_remain - && s_hist[tx + 1] <= s_last_remain) { - s_sub_threshold_bin = tx; - s_last_remain = s_last_remain - s_hist[tx + 1]; - } - if (tx == 0 && s_sub_threshold_bin == -1) { - s_sub_threshold_bin = RADIX; // no tie-bin refinement needed - } - __syncthreads(); - const int sub_threshold_bin_m = s_sub_threshold_bin; - - // (5) Emit tie-bin items via warp-popc + sub-threshold budget. - for (int it = 0; it < num_iters_m; ++it) { - const int i = it * blockDim.x + tx; - bool in_threshold = false; - int sub_bin = -1; - if (i < total) { - const int32_t idx = s_indices[i]; - if (idx >= 0) { - const uint32_t b32 = convert_to_uint32(s_scores[i]); - const int bin = (b32 >> 24) & 0xFF; - if (bin == threshold_bin_m) { - in_threshold = true; - sub_bin = (b32 >> 16) & 0xFF; - } - } - } - const bool take_sub_above = in_threshold && (sub_bin > sub_threshold_bin_m); - const int slot = warp_compact_slot(take_sub_above, &s_counter); - if (take_sub_above) { - out_idx[slot] = s_indices[i]; - } else if (in_threshold && sub_bin == sub_threshold_bin_m) { - const int pos = ::atomicAdd(&s_last_remain, -1); - if (pos > 0) out_idx[K - pos] = s_indices[i]; - } - } -} - -// ---- setup_kernel_smem_once ------------------------------------------------ - -template -void setup_kernel_smem_once() { - [[maybe_unused]] - static const auto result = [] { - return ::cudaFuncSetAttribute( - f, ::cudaFuncAttributeMaxDynamicSharedMemorySize, max_dynamic_smem); - }(); - TORCH_CHECK(result == cudaSuccess, - "fast_fused_topk_merge setup failed: ", - ::cudaGetErrorString(result)); -} - -} // namespace - -#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") - -// ============================================================================ -// Host entry point. -// -// score [batch_size, num_chunks, chunk_size] bf16 or f32 -// global_topk_indices [batch_size, topk_val] int32 (output) -// -// ONE kernel launch. The per-chunk selection (Phase 1) and the -// cross-chunk merge (Phase 2) are fused in TopK_Parallel_Kernel via a -// last-CTA-wins atomicInc barrier. A per-call workspace holds the -// [batch, N, K] partial top-K that the last CTA reads from; the -// done-counter is a program-lifetime __device__ global so nothing -// needs memsetting on the hot path. -// ============================================================================ -void fast_fused_topk_merge( - const at::Tensor& score, - at::Tensor& global_topk_indices, - const int64_t batch_size, - const int64_t num_chunks, - const int64_t chunk_size, - const int64_t topk_val, - const int64_t mapping_mode, - const double mapping_power) -{ - CHECK_CUDA(score); - CHECK_CUDA(global_topk_indices); - - TORCH_CHECK(topk_val > 0 && topk_val <= VORTEX_MAX_TOPK, - "fast_fused_topk_merge: topk_val=", topk_val, - " must be in (0, ", VORTEX_MAX_TOPK, "]"); - TORCH_CHECK(num_chunks >= 1, "num_chunks must be >= 1"); - TORCH_CHECK(batch_size >= 1, "batch_size must be >= 1"); - TORCH_CHECK(batch_size <= kMaxBatch, - "fast_fused_topk_merge: batch_size ", batch_size, - " exceeds the __device__ done-counter cap (", kMaxBatch, ")"); - TORCH_CHECK(chunk_size >= 1, "chunk_size must be >= 1"); - TORCH_CHECK(num_chunks * topk_val <= kMergeCap, - "fast_fused_topk_merge: num_chunks*topk_val (", - num_chunks * topk_val, ") exceeds merge cap (", kMergeCap, - "). Reduce num_chunks or topk_val."); - TORCH_CHECK(global_topk_indices.scalar_type() == at::kInt, - "global_topk_indices must be int32"); - TORCH_CHECK(global_topk_indices.numel() >= batch_size * topk_val, - "global_topk_indices is too small for batch_size * topk_val"); - - TORCH_CHECK( - mapping_mode == MAPPING_POWER || - mapping_mode == MAPPING_ASINH || - mapping_mode == MAPPING_LOG1P || - mapping_mode == MAPPING_ERF || - mapping_mode == MAPPING_TANH || - mapping_mode == MAPPING_SUBTRACT || - mapping_mode == MAPPING_EXP_STRETCH || - mapping_mode == MAPPING_SHIFT_POW2 || - mapping_mode == MAPPING_SHIFT_POW3 || - mapping_mode == MAPPING_LINEAR_STEEP, - "fast_fused_topk_merge: mapping_mode=", mapping_mode, - " not supported. Valid: POWER(3), ASINH(6), LOG1P(7), ERF(9), " - "TANH(10), SUBTRACT(11), EXP_STRETCH(13), SHIFT_POW2(15), " - "SHIFT_POW3(16), LINEAR_STEEP(17)."); - - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - const float mp = static_cast(mapping_power); - - // Dynamic smem must fit whichever phase is larger: - // Phase 1: chunk_size floats + chunk_size bytes. - // Phase 2: num_chunks*topk_val * (float + int32). - const size_t p1_bytes = static_cast(chunk_size) * sizeof(float) - + ((static_cast(chunk_size) + 15) & ~size_t(15)); - const size_t p2_bytes = static_cast(num_chunks) * - static_cast(topk_val) * - (sizeof(float) + sizeof(int32_t)); - const size_t smem_bytes = p1_bytes > p2_bytes ? p1_bytes : p2_bytes; - TORCH_CHECK(smem_bytes <= kMaxDynSmem, - "fast_fused_topk_merge: smem ", smem_bytes, - " > ceiling ", kMaxDynSmem); - - // Per-call workspace for the [batch, N, K] partial top-K. at::empty - // hits the caching allocator (no cudaMalloc in the hot path after - // warmup). The done-counter lives in __device__ memory — no memset. - auto opts_f32 = at::TensorOptions().device(score.device()).dtype(at::kFloat); - auto opts_i32 = at::TensorOptions().device(score.device()).dtype(at::kInt); - const int64_t ws_elems = batch_size * num_chunks * topk_val; - at::Tensor partial_keys = at::empty({ws_elems}, opts_f32); - at::Tensor partial_idx = at::empty({ws_elems}, opts_i32); - - dim3 grid(static_cast(batch_size), - static_cast(num_chunks)); - dim3 block(kThreadsPerBlock); - - #define LAUNCH(DTYPE, PTR_EXPR, MODE_VAL) \ - do { \ - setup_kernel_smem_once, \ - kMaxDynSmem>(); \ - TopK_Parallel_Kernel \ - <<>>( \ - PTR_EXPR, \ - global_topk_indices.data_ptr(), \ - partial_keys.data_ptr(), \ - partial_idx.data_ptr(), \ - static_cast(num_chunks), \ - static_cast(chunk_size), \ - static_cast(topk_val), \ - mp); \ - } while (0) - - #define DISPATCH_MODE(DTYPE, PTR_EXPR) \ - do { \ - switch (mapping_mode) { \ - case MAPPING_POWER: LAUNCH(DTYPE, PTR_EXPR, MAPPING_POWER); break; \ - case MAPPING_ASINH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ASINH); break; \ - case MAPPING_LOG1P: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LOG1P); break; \ - case MAPPING_ERF: LAUNCH(DTYPE, PTR_EXPR, MAPPING_ERF); break; \ - case MAPPING_TANH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_TANH); break; \ - case MAPPING_SUBTRACT: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SUBTRACT); break; \ - case MAPPING_EXP_STRETCH: LAUNCH(DTYPE, PTR_EXPR, MAPPING_EXP_STRETCH); break; \ - case MAPPING_SHIFT_POW2: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW2); break; \ - case MAPPING_SHIFT_POW3: LAUNCH(DTYPE, PTR_EXPR, MAPPING_SHIFT_POW3); break; \ - case MAPPING_LINEAR_STEEP: LAUNCH(DTYPE, PTR_EXPR, MAPPING_LINEAR_STEEP); break; \ - default: TORCH_CHECK(false, "unreachable mode"); \ - } \ - } while (0) - - if (score.scalar_type() == at::ScalarType::BFloat16) { - DISPATCH_MODE(__nv_bfloat16, - reinterpret_cast<__nv_bfloat16*>(score.data_ptr())); - } else if (score.scalar_type() == at::ScalarType::Float) { - DISPATCH_MODE(float, score.data_ptr()); - } else { - TORCH_CHECK(false, "fast_fused_topk_merge: unsupported dtype ", - score.scalar_type()); - } - - #undef DISPATCH_MODE - #undef LAUNCH - - const auto rc = cudaGetLastError(); - TORCH_CHECK(rc == cudaSuccess, - "fast_fused_topk_merge kernel failed: ", ::cudaGetErrorString(rc)); -} diff --git a/examples/run_distribution_analysis.sh b/examples/archived/run_distribution_analysis.sh similarity index 100% rename from examples/run_distribution_analysis.sh rename to examples/archived/run_distribution_analysis.sh diff --git a/examples/verify_algo_topk_mapping.sh b/examples/archived/verify_algo_topk_mapping.sh similarity index 100% rename from examples/verify_algo_topk_mapping.sh rename to examples/archived/verify_algo_topk_mapping.sh diff --git a/examples/plot_parallel_comparison.py b/examples/plot_parallel_comparison.py new file mode 100644 index 00000000..3d3f4883 --- /dev/null +++ b/examples/plot_parallel_comparison.py @@ -0,0 +1,384 @@ +#!/usr/bin/env python3 +"""Aggregate baseline / fused / adaptive (split) TopK latencies across remap +functions and emit CSV tables + matplotlib bar plots. + +Reads one remap_bench_*.json per (topk_val, num_splits) tag from the +directories produced by remap_function_bench_topk_parallel.sh and writes: + + results.csv long-form per-(K, splits, batch, mode, dist) rows + summary_topk.csv wide table per K (averaged across batch sizes) + summary_all.csv single combined wide table covering every K + comparison_topk.png bar plot per K (one bar group per mode) + comparison_all.png side-by-side per-K plots + +Input format: + --input "K=2048,splits=ns2=path/to/remap_bench_ns2.json" + --input "K=30,splits=auto=path/to/remap_bench_auto.json" + +The legacy "tag=path" input is also accepted; it lands in the all-K combined +plot but won't fill in the K/splits columns of results.csv. + +Usage: + python plot_parallel_comparison.py \ + --input "K=2048,splits=ns2=.../remap_bench_ns2.json" \ + --input "K=30,splits=auto=.../remap_bench_auto.json" \ + --output-dir /analysis [--emit-csv] [--emit-png] +""" +from __future__ import annotations + +import argparse +import csv +import json +import math +import re +from pathlib import Path +from typing import Dict, List, Tuple + +import numpy as np +import matplotlib + +matplotlib.use("Agg") +import matplotlib.pyplot as plt + + +MODE_DISPLAY = { + 0: "None", + 3: "Power", + 4: "Log", + 6: "Asinh", + 7: "Log1p", + 9: "Erf", + 10: "Tanh", + 11: "Subtract", + 13: "ExpStretch", + 15: "ShiftPow2", + 16: "ShiftPow3", + 17: "LinearSteep", + 18: "HalfSquare", + 19: "HalfCube", +} + + +# --- input parsing ----------------------------------------------------------- + + +def _parse_input_spec(spec: str) -> Tuple[str, dict, Path]: + """Parse "K=2048,splits=ns2=path" -> ("K=2048,splits=ns2", {"K": "2048", + "splits": "ns2"}, Path("path")). + + Falls back to the legacy "tag=path" format if no comma-separated + key=value attrs precede the trailing "=path" segment. + """ + if "=" not in spec: + raise SystemExit(f"--input expects tag=path, got {spec!r}") + # Split on the LAST '=' that doesn't follow a comma (the path delim). + # Handle by scanning from the right. + eq_positions = [i for i, ch in enumerate(spec) if ch == "="] + path: str = "" + label: str = spec + for idx in reversed(eq_positions): + candidate_path = spec[idx + 1:] + if "/" in candidate_path or candidate_path.endswith(".json"): + label = spec[:idx] + path = candidate_path + break + if not path: + # last-resort: split on the rightmost '=' + label, path = spec.rsplit("=", 1) + + p = Path(path) + if not p.exists(): + raise SystemExit(f"{p} not found (input spec: {spec!r})") + + attrs: Dict[str, str] = {} + for segment in label.split(","): + segment = segment.strip() + if not segment or "=" not in segment: + continue + k, v = segment.split("=", 1) + attrs[k.strip()] = v.strip() + return label, attrs, p + + +def _load_rows(json_path: Path) -> List[dict]: + with open(json_path) as f: + data = json.load(f) + return data if isinstance(data, list) else data.get("results", []) + + +# --- aggregation ------------------------------------------------------------ + + +def _per_mode_rows(rows: List[dict], distribution: str | None = None): + """Yield one dict per (config, mode) pair so we can flatten to CSV.""" + for cfg in rows: + if distribution is not None and cfg.get("distribution") != distribution: + continue + cfg_keys = { + "batch_size": cfg.get("batch_size"), + "num_kv_heads": cfg.get("num_kv_heads"), + "seq_len": cfg.get("seq_len"), + "topk_val": cfg.get("topk_val"), + "distribution": cfg.get("distribution"), + "pages_per_seg": cfg.get("pages_per_seg"), + "head": cfg.get("head", "all"), + "baseline_ms": cfg.get("baseline_ms"), + } + for m in cfg.get("modes", []): + mode_id = m.get("mode") + if mode_id is None or mode_id < 0: + continue + yield { + **cfg_keys, + "mode": mode_id, + "mode_name": MODE_DISPLAY.get(mode_id, m.get("mode_name", f"m{mode_id}")), + "power": m.get("power"), + "fused_ms": m.get("fused_ms") + if m.get("fused_ms") is not None + else (m.get("topk_after_remap_ms") + if mode_id == 0 else None), + "parallel_ms": m.get("parallel_ms"), + "parallel_splits": m.get("parallel_splits"), + "remap_ms": m.get("remap_ms"), + "split_total_ms": m.get("split_total_ms"), + } + + +def _aggregate_per_mode(rows: List[dict], distribution: str = "real"): + """Return { mode -> {baseline, fused, parallel} } averaged across configs. + Falls back to all distributions if `distribution` is empty for these rows. + """ + used_dist = distribution + flat = list(_per_mode_rows(rows, distribution)) + if not flat: + used_dist = None + flat = list(_per_mode_rows(rows, None)) + out: Dict[int, Dict[str, List[float]]] = {} + for r in flat: + bucket = out.setdefault(r["mode"], + {"baseline_ms": [], "fused_ms": [], "parallel_ms": []}) + if r.get("baseline_ms") is not None: bucket["baseline_ms"].append(r["baseline_ms"]) + if r.get("fused_ms") is not None: bucket["fused_ms"].append(r["fused_ms"]) + if r.get("parallel_ms") is not None: bucket["parallel_ms"].append(r["parallel_ms"]) + summary = { + m: {k: (sum(v) / len(v) if v else float("nan")) for k, v in sub.items()} + for m, sub in out.items() + } + return summary, used_dist + + +# --- CSV writers ------------------------------------------------------------ + + +def _write_results_csv(records: List[dict], out_path: Path) -> None: + """Long-form per-(K, splits, batch, mode, dist) records → results.csv.""" + if not records: + out_path.write_text("") + return + fields = list({k for r in records for k in r.keys()}) + # Stable column order: identifiers first. + preferred = [ + "label", "K", "splits", "batch_size", "num_kv_heads", "seq_len", + "topk_val", "distribution", "pages_per_seg", "head", + "mode", "mode_name", "power", + "baseline_ms", "fused_ms", "parallel_ms", "parallel_splits", + "remap_ms", "split_total_ms", + ] + head = [c for c in preferred if c in fields] + tail = sorted(c for c in fields if c not in preferred) + cols = head + tail + with open(out_path, "w", newline="") as f: + w = csv.DictWriter(f, fieldnames=cols) + w.writeheader() + for r in records: + w.writerow({c: r.get(c, "") for c in cols}) + print(f" wrote {out_path} ({len(records)} rows)") + + +def _write_summary_csv(tag: str, summary: Dict[int, Dict[str, float]], + out_path: Path, *, attrs: Dict[str, str] | None = None) -> None: + attrs = attrs or {} + cols = ["K", "splits", "tag", "mode", "mode_name", + "baseline_ms", "fused_ms", "parallel_ms", + "fused_speedup", "parallel_speedup"] + with open(out_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for mode_id in sorted(summary): + s = summary[mode_id] + base = s.get("baseline_ms", float("nan")) + fused = s.get("fused_ms", float("nan")) + par = s.get("parallel_ms", float("nan")) + fs = (base / fused) if fused and not math.isnan(fused) else float("nan") + ps = (base / par) if par and not math.isnan(par) else float("nan") + w.writerow([ + attrs.get("K", ""), + attrs.get("splits", ""), + tag, + mode_id, + MODE_DISPLAY.get(mode_id, f"m{mode_id}"), + _csv_num(base), _csv_num(fused), _csv_num(par), + _csv_num(fs), _csv_num(ps), + ]) + print(f" wrote {out_path}") + + +def _write_summary_all_csv(summaries, attrs_by_tag, out_path: Path) -> None: + cols = ["tag", "K", "splits", "mode", "mode_name", + "baseline_ms", "fused_ms", "parallel_ms", + "fused_speedup", "parallel_speedup"] + with open(out_path, "w", newline="") as f: + w = csv.writer(f) + w.writerow(cols) + for tag, summary in summaries.items(): + attrs = attrs_by_tag.get(tag, {}) + for mode_id in sorted(summary): + s = summary[mode_id] + base = s.get("baseline_ms", float("nan")) + fused = s.get("fused_ms", float("nan")) + par = s.get("parallel_ms", float("nan")) + fs = (base / fused) if fused and not math.isnan(fused) else float("nan") + ps = (base / par) if par and not math.isnan(par) else float("nan") + w.writerow([ + tag, + attrs.get("K", ""), + attrs.get("splits", ""), + mode_id, + MODE_DISPLAY.get(mode_id, f"m{mode_id}"), + _csv_num(base), _csv_num(fused), _csv_num(par), + _csv_num(fs), _csv_num(ps), + ]) + print(f" wrote {out_path}") + + +def _csv_num(x): + if x is None or (isinstance(x, float) and math.isnan(x)): + return "" + return f"{x:.6f}" + + +# --- plotting --------------------------------------------------------------- + + +def _plot_bars(tag: str, summary: Dict[int, Dict[str, float]], out_path: Path) -> None: + modes = sorted(summary.keys()) + labels = [MODE_DISPLAY.get(m, f"m{m}") for m in modes] + base = [summary[m].get("baseline_ms", float("nan")) for m in modes] + fused = [summary[m].get("fused_ms", float("nan")) for m in modes] + par = [summary[m].get("parallel_ms", float("nan")) for m in modes] + + x = np.arange(len(modes)) + w = 0.27 + fig, ax = plt.subplots(figsize=(max(8, 0.85 * len(modes)), 5)) + ax.bar(x - w, base, w, label="Baseline (sglang)", color="#888888") + ax.bar(x, fused, w, label="Fused (sglang_fused)", color="#4C72B0") + ax.bar(x + w, par, w, label="Adaptive (output_adaptive)", color="#C44E52") + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=30, ha="right") + ax.set_ylabel("Latency (ms, lower is better)") + ax.set_title(f"TopK kernel latency — {tag}") + ax.grid(True, axis="y", linestyle="--", alpha=0.4) + ax.legend(loc="upper right") + fig.tight_layout() + fig.savefig(out_path, dpi=150) + plt.close(fig) + print(f" wrote {out_path}") + + +def _plot_combined(summaries: Dict[str, Dict[int, Dict[str, float]]], out_path: Path) -> None: + if not summaries: + return + tags = list(summaries.keys()) + fig, axes = plt.subplots(1, len(tags), figsize=(max(8, 7 * len(tags)), 5), sharey=False) + if len(tags) == 1: + axes = [axes] + for ax, tag in zip(axes, tags): + summary = summaries[tag] + modes = sorted(summary.keys()) + labels = [MODE_DISPLAY.get(m, f"m{m}") for m in modes] + base = [summary[m].get("baseline_ms", float("nan")) for m in modes] + fused = [summary[m].get("fused_ms", float("nan")) for m in modes] + par = [summary[m].get("parallel_ms", float("nan")) for m in modes] + x = np.arange(len(modes)) + w = 0.27 + ax.bar(x - w, base, w, label="Baseline", color="#888888") + ax.bar(x, fused, w, label="Fused", color="#4C72B0") + ax.bar(x + w, par, w, label="Adaptive", color="#C44E52") + ax.set_xticks(x) + ax.set_xticklabels(labels, rotation=30, ha="right") + ax.set_ylabel("Latency (ms)") + ax.set_title(tag) + ax.grid(True, axis="y", linestyle="--", alpha=0.4) + ax.legend(loc="upper right", fontsize=8) + fig.suptitle("Adaptive (split) vs Fused vs Baseline TopK", y=1.02) + fig.tight_layout() + fig.savefig(out_path, bbox_inches="tight", dpi=150) + plt.close(fig) + print(f" wrote {out_path}") + + +# --- main ------------------------------------------------------------------- + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--input", action="append", required=True, + help="=path/to/remap_bench_*.json (repeatable).") + p.add_argument("--output-dir", required=True) + p.add_argument("--distribution", default="real", + help="Distribution column to aggregate (falls back to all).") + p.add_argument("--emit-csv", action="store_true", + help="Write CSV tables (always on; flag kept for explicitness).") + p.add_argument("--emit-png", action="store_true", + help="Write PNG plots (always on; flag kept for explicitness).") + args = p.parse_args() + # Always emit both — flags are kept so the shell wrapper can document intent. + args.emit_csv = True + args.emit_png = True + + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + + summaries: Dict[str, Dict[int, Dict[str, float]]] = {} + attrs_by_tag: Dict[str, Dict[str, str]] = {} + long_form: List[dict] = [] + + for spec in args.input: + tag, attrs, path = _parse_input_spec(spec) + rows = _load_rows(path) + + # Long-form rows for results.csv + for r in _per_mode_rows(rows, args.distribution) or _per_mode_rows(rows, None): + long_form.append({ + "label": tag, + "K": attrs.get("K", ""), + "splits": attrs.get("splits", ""), + **r, + }) + + summary, _used_dist = _aggregate_per_mode(rows, distribution=args.distribution) + summaries[tag] = summary + attrs_by_tag[tag] = attrs + + if args.emit_csv: + _write_summary_csv(tag, summary, + out_dir / f"summary_{_safe(tag)}.csv", + attrs=attrs) + if args.emit_png: + _plot_bars(tag, summary, out_dir / f"comparison_{_safe(tag)}.png") + + if args.emit_csv: + _write_results_csv(long_form, out_dir / "results.csv") + _write_summary_all_csv(summaries, attrs_by_tag, out_dir / "summary_all.csv") + + if args.emit_png and len(summaries) > 1: + _plot_combined(summaries, out_dir / "comparison_all.png") + + +def _safe(tag: str) -> str: + """Make a tag safe for use in a filename (strip ',', '=', '/').""" + return re.sub(r"[^A-Za-z0-9._-]", "_", tag) + + +if __name__ == "__main__": + main() diff --git a/examples/profile_in_docker.sh b/examples/profile_in_docker.sh deleted file mode 100755 index a64606ff..00000000 --- a/examples/profile_in_docker.sh +++ /dev/null @@ -1,181 +0,0 @@ -#!/usr/bin/env bash -# ============================================================ -# Run examples/profile_parallel_vs_fused.sh inside an NVIDIA -# CUDA devel container so we can enable profiling without -# touching the host's RmProfilingAdminOnly=1 setting. -# -# Key idea: -# - The container has `ncu` bundled with the CUDA toolkit. -# - --cap-add=SYS_ADMIN gives the container the capability -# CUPTI needs to access perf counters, so ncu works -# regardless of the host's nvidia-driver profiling restriction. -# - We mount the host's uv venv and the project, so there's -# no Python/pytorch install inside the container — the host -# venv's python is used directly. -# -# Image: -# Defaults to an NGC public CUDA devel image. For B200 (Blackwell / -# sm_100) you need CUDA ≥ 12.8 and ncu ≥ 2024.3; CUDA 13.0+ covers -# that. Override with NCU_IMAGE if you prefer a specific tag. -# -# Usage: -# bash examples/profile_in_docker.sh # defaults -# GPU=2 NUM_SPLITS=2 bash examples/profile_in_docker.sh -# NCU_IMAGE=nvcr.io/nvidia/pytorch:25.03-py3 bash examples/profile_in_docker.sh -# ============================================================ -set -euo pipefail - -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PROJECT_DIR="$(cd "${SCRIPT_DIR}/.." && pwd)" - -# Host venv to reuse. uv venvs have their own python binary under -# $VENV/bin/python3.x that is glibc/libstdc++-compatible with the -# container when using NGC Ubuntu 22.04 / 24.04 images. -VENV_DIR="${VENV_DIR:-/home/zhuominc/xinrui_projects/uv_env/vortex}" - -# NGC CUDA devel image on Ubuntu. Has /usr/local/cuda/bin/ncu bundled. -# 13.0.1-devel-ubuntu22.04 is public (no NGC login needed), supports -# B200, and matches the host's CUDA 13.x driver ABI. -# -# Alternatives: -# nvcr.io/nvidia/cuda:13.0.1-devel-ubuntu24.04 # newer base -# nvcr.io/nvidia/pytorch:25.03-py3 # if you don't want to -# # reuse the host venv -# Host is Ubuntu 24.04 + Python 3.12 (the uv venv points to /usr/bin/python3.12). -# Match the container to that so the venv's symlinked python resolves to a -# compatible interpreter inside the container. -NCU_IMAGE="${NCU_IMAGE:-nvcr.io/nvidia/cuda:13.0.1-devel-ubuntu24.04}" - -# Pass-through env vars for the inner profile script. Defaults match -# examples/profile_parallel_vs_fused.sh. -GPU="${GPU:-7}" -EFF_BS="${EFF_BS:-1}" -NUM_SPLITS="${NUM_SPLITS:-2}" -POWER="${POWER:--1.0}" -WARMUP="${WARMUP:-20}" -ITERS="${ITERS:-1}" -SECTION_SET="${SECTION_SET:-full}" - -# Inside the container, these mount points give the profile script the -# same absolute paths it sees on the host (so the script doesn't need -# to be container-aware). -MOUNT_ROOT="/home/zhuominc/xinrui_projects" - -if [ ! -d "${VENV_DIR}" ]; then - echo "ERROR: VENV_DIR not found: ${VENV_DIR}" - echo " Set VENV_DIR=/path/to/venv or install the venv." - exit 1 -fi - -VENV_PY="$(ls "${VENV_DIR}"/bin/python* 2>/dev/null | head -1 || true)" -if [ -z "${VENV_PY}" ]; then - echo "ERROR: no python found under ${VENV_DIR}/bin/" - exit 1 -fi - -echo "============================================================" -echo "Docker-wrapped ncu profiling" -echo " image: ${NCU_IMAGE}" -echo " venv: ${VENV_DIR} (python=${VENV_PY##*/})" -echo " project: ${PROJECT_DIR}" -echo " GPU: ${GPU}" -echo " eff_bs: ${EFF_BS}" -echo " num_splits: ${NUM_SPLITS}" -echo " power: ${POWER}" -echo " warmup/iters: ${WARMUP}/${ITERS}" -echo " section set: ${SECTION_SET}" -echo "============================================================" - -# Pull the image up-front (so the output during the run isn't -# interleaved with pull progress). `|| true` — pull is optional; -# if the image is already local, docker run will use the cached copy. -docker pull "${NCU_IMAGE}" || true - -# Run the profile script inside the container. -# -# --gpus all : give the container access to all GPUs -# (CUDA_VISIBLE_DEVICES inside the script -# narrows it down to GPU ${GPU}). -# --cap-add=SYS_ADMIN : lets CUPTI access perf counters without -# touching host profiling restrictions. -# --security-opt seccomp=unconfined : CUPTI needs a few syscalls -# the default seccomp profile blocks. -# --network host : not strictly required, but keeps pip/uv -# network access working if you ever add -# pip-install steps. -# --user $(id -u):$(id -g) -# : write output files owned by your user, -# not root. -# -v /etc/passwd:/etc/passwd:ro -v /etc/group:/etc/group:ro -# : so the uid inside resolves to a real -# user (helps some tools, harmless otherwise). -# -v ${MOUNT_ROOT}:${MOUNT_ROOT} -# : mount the whole xinrui_projects tree so -# both the project and the venv are visible -# at their host paths. -# -e PYTHONPATH=... : add the venv's site-packages explicitly -# so `python3 -c 'import vortex_torch_C'` -# resolves even without activate. -# -e PATH=... : put the venv's bin ahead of /usr/local/cuda/bin -# so `python` is the venv python, and keep ncu -# reachable. -# When invoked via `sudo`, `id -u` returns 0 (root). Prefer SUDO_UID/ -# SUDO_GID so the final chown hands results back to the real user, -# not root. Fall back to the effective uid/gid otherwise. -HOST_UID="${SUDO_UID:-$(id -u)}" -HOST_GID="${SUDO_GID:-$(id -g)}" - -docker run --rm \ - --gpus all \ - --cap-add=SYS_ADMIN \ - --security-opt seccomp=unconfined \ - --network host \ - --ipc=host \ - -e DISPLAY="${DISPLAY:-}" \ - -v /tmp/.X11-unix:/tmp/.X11-unix \ - -v "${MOUNT_ROOT}:${MOUNT_ROOT}" \ - -w "${PROJECT_DIR}" \ - -e GPU="${GPU}" \ - -e EFF_BS="${EFF_BS}" \ - -e NUM_SPLITS="${NUM_SPLITS}" \ - -e POWER="${POWER}" \ - -e WARMUP="${WARMUP}" \ - -e ITERS="${ITERS}" \ - -e SECTION_SET="${SECTION_SET}" \ - -e NCU="/usr/local/cuda/bin/ncu" \ - -e HOST_UID="${HOST_UID}" \ - -e HOST_GID="${HOST_GID}" \ - -e PATH="${VENV_DIR}/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" \ - -e LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH:-}" \ - "${NCU_IMAGE}" \ - bash -lc ' - set -e - # Ubuntu 24.04 base may not ship python3.12 in the CUDA devel image. - # Install it idempotently; this is ~2s if missing and skipped otherwise. - if [ ! -x /usr/bin/python3.12 ]; then - echo "--- installing python3.12 in container ---" - export DEBIAN_FRONTEND=noninteractive - apt-get update -qq - apt-get install -y --no-install-recommends python3.12 >/dev/null - fi - echo "--- container environment ---" - echo "python: $(readlink -f "$(which python)") ($(python --version 2>&1))" - echo "ncu: $(which ncu)" - ncu --version 2>&1 | head -2 - nvidia-smi -L - python -c "import torch; print(\"torch: \", torch.__version__, \"cuda:\", torch.version.cuda)" - python -c "import vortex_torch_C; print(\"vortex_torch_C import OK\")" - echo "-----------------------------" - bash examples/profile_parallel_vs_fused.sh - # Hand output files back to the host user (we ran as root so apt - # could install python3.12). - chown -R "${HOST_UID}:${HOST_GID}" examples/results 2>/dev/null || true - ' - -echo "" -echo "============================================================" -echo "Docker profiling run complete." -echo "Reports are under: ${PROJECT_DIR}/examples/results/" -echo "(same path as the direct script — you own the files since we" -echo " ran the container as your uid)." -echo "============================================================" diff --git a/examples/profile_parallel_vs_fused_ncu.sh b/examples/profile_parallel_vs_fused_ncu.sh deleted file mode 100755 index bf9baaf1..00000000 --- a/examples/profile_parallel_vs_fused_ncu.sh +++ /dev/null @@ -1,277 +0,0 @@ -#!/usr/bin/env bash -# ============================================================ -# Nsight Compute profiling script for the parallel vs fused -# TopK kernels. -# -# Profiles both: -# - TopKOutput_Fused_Kernel (csrc/topk_sglang.cu) -# - TopKOutput_Parallel_Kernel (csrc/topk_sglang_parallel.cu) -# -# With both remap functions the user cares about: -# - mode 15: MAPPING_SHIFT_POW2 -# - mode 16: MAPPING_SHIFT_POW3 -# -# And both configs: -# - A: topk=2048, pages_per_seg=32K (topk=2k from 32k) -# - B: topk=30, pages_per_seg=2K (topk=30 from 2k) -# -# Produces one .ncu-rep per (kernel × mode × config). Open with -# the Nsight Compute GUI for an interactive comparison, or dump on -# the CLI with `ncu --import .ncu-rep --page details`. -# -# Usage: -# bash examples/profile_parallel_vs_fused.sh # defaults -# GPU=4 EFF_BS=1 bash examples/profile_parallel_vs_fused.sh # small-batch case -# GPU=4 EFF_BS=32 bash examples/profile_parallel_vs_fused.sh # saturated case -# GPU=4 NUM_SPLITS=2 bash examples/profile_parallel_vs_fused.sh -# -# Requires `ncu` on PATH (part of the CUDA toolkit). On most systems -# accessing performance counters requires either: -# - root/sudo, or -# - `echo 1 | sudo tee /proc/driver/nvidia/params` (temporary), or -# - setting NVreg_RestrictProfilingToAdminUsers=0 in the nvidia driver. -# If ncu reports "ERR_NVGPUCTRPERM" you'll need one of the above. -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PY_DRIVER="${SCRIPT_DIR}/../benchmarks/profile_parallel_vs_fused.py" - -# ── Defaults ────────────────────────────────────────────────── -GPU=${GPU:-7} -EFF_BS=${EFF_BS:-1} # eff_batch_size = batch_size * num_kv_heads -NUM_SPLITS=${NUM_SPLITS:-2} # only used by the parallel kernel -POWER=${POWER:--1.0} # pivot p for shift_pow{2,3} -WARMUP=${WARMUP:-20} # matching-kernel warmup launches (ncu skips) -ITERS=${ITERS:-1} # matching-kernel profiled launches (ncu captures) -SECTION_SET=${SECTION_SET:-full} # ncu section set: "full", "basic", or named sections - -# Profiling robustness knobs for shared GPUs / CUDA 13 systems. -# --replay-mode application: re-run the entire process to collect each -# counter pass, instead of replaying individual -# kernels. Fixes "Failed to prepare kernel" on -# systems where kernel replay hits PMU conflicts. -# --clock-control none : don't try to lock GPU clocks (requires admin on -# shared GPUs; without this, "Unknown error on -# device 0" is common). -# --cache-control none : don't flush L1/L2 between passes (also needs -# admin on shared systems). -# Override with NCU_EXTRA_FLAGS="..." if you need a different combination. -NCU_EXTRA_FLAGS=${NCU_EXTRA_FLAGS:-"--replay-mode application --clock-control none --cache-control none"} - -# DIAG=1 bash profile_parallel_vs_fused.sh → run one tiny ncu probe to -# verify profiling works before doing the full sweep. -DIAG=${DIAG:-0} - -# ── ncu command ─────────────────────────────────────────────── -NCU=${NCU:-ncu} -command -v "${NCU}" >/dev/null 2>&1 || { - echo "ERROR: '${NCU}' not found on PATH. Install Nsight Compute (part of CUDA Toolkit)" - echo " or set NCU=/path/to/ncu and re-run." - exit 1 -} - -# The templated kernels end up with mangled names like -# _Z25TopKOutput_Fused_KernelI13__nv_bfloat16ILi15EEEvPKT_... -# ncu supports --kernel-name regex: which matches on the -# demangled signature. Using "TopKOutput_Fused_Kernel" and -# "TopKOutput_Parallel_Kernel" as the regex selects all template -# instantiations of each kernel but nothing else. -FUSED_REGEX="regex:TopKOutput_Fused_Kernel" -PARALLEL_REGEX="regex:TopKOutput_Parallel_Kernel" - -# ── Output dir ──────────────────────────────────────────────── -TIMESTAMP=$(date +%Y%m%d_%H%M%S) -OUT_DIR="${SCRIPT_DIR}/results/ncu_parallel_vs_fused_${TIMESTAMP}" -mkdir -p "${OUT_DIR}" - -echo "============================================================" -echo "Nsight Compute profile: parallel vs fused TopK" -echo " GPU: ${GPU}" -echo " eff_bs: ${EFF_BS}" -echo " num_splits: ${NUM_SPLITS} (parallel kernel only)" -echo " power (p): ${POWER} (for shift_pow{2,3})" -echo " warmup: ${WARMUP} (matching-kernel launches skipped by ncu)" -echo " iters: ${ITERS} (matching-kernel launches captured)" -echo " sections: --set ${SECTION_SET}" -echo " extra ncu flags:${NCU_EXTRA_FLAGS}" -echo " output dir: ${OUT_DIR}" -echo "============================================================" - -# ── Diagnostic probe ───────────────────────────────────────── -# Verifies that ncu can attach and collect at least one section on -# this GPU before we burn time on the full sweep. Uses --set basic -# which is the cheapest section set. If this fails, see the -# TROUBLESHOOTING block that the script prints on error. -run_diag() { - echo "" - echo ">>> Diagnostic probe: can ncu attach at all?" - local out="${OUT_DIR}/diag.ncu-rep" - set +e - CUDA_VISIBLE_DEVICES="${GPU}" "${NCU}" \ - --force-overwrite \ - --target-processes all \ - --kernel-name "${FUSED_REGEX}" \ - --launch-skip "${WARMUP}" \ - --launch-count 1 \ - --set basic \ - ${NCU_EXTRA_FLAGS} \ - --export "${out}" \ - python "${PY_DRIVER}" \ - --config A --eff-bs 1 --mode 15 --power "${POWER}" \ - --num-splits "${NUM_SPLITS}" --kernel fused \ - --warmup "${WARMUP}" --iters 1 - local rc=$? - set -e - if [ ${rc} -ne 0 ]; then - cat <<'EOF' - -============================================================ -TROUBLESHOOTING "Failed to prepare kernel for profiling" -============================================================ - 1) Is another process using GPU ${GPU}? Check: - nvidia-smi - If yes, pick an idle GPU: - GPU=0 bash examples/profile_parallel_vs_fused.sh - - 2) Perf counters may be locked to admin. Try as root: - sudo -E bash examples/profile_parallel_vs_fused.sh - - Or permanently unlock (admin, persists until reboot): - sudo sh -c 'echo 1 > /proc/driver/nvidia/params' - - Or permanently in the driver (needs reboot): - Add NVreg_RestrictProfilingToAdminUsers=0 to - /etc/modprobe.d/nvidia.conf - - 3) MPS or another profiler (CUPTI, Nsight Systems, etc.) - may be running. Kill with: - echo quit | nvidia-cuda-mps-control - and verify nothing else is profiling. - - 4) On H100 with MIG: profiling across MIG slices is - restricted. Use a full-device GPU. - - 5) Try a smaller ncu configuration first: - NCU_EXTRA_FLAGS="--replay-mode application --clock-control none --cache-control none --metrics sm__cycles_elapsed.avg" \ - bash examples/profile_parallel_vs_fused.sh - - 6) CUDA 13.2 vs PyTorch-13.0 mismatch is sometimes flagged - by ncu. Update ncu to match CUDA 13.2, or use the ncu - shipped with CUDA 13.2: - NCU=/usr/local/cuda-13.2/bin/ncu bash ... - -============================================================ -EOF - echo "Diagnostic probe failed (exit ${rc}). See troubleshooting above." - exit ${rc} - fi - echo ">>> Diagnostic probe OK. Proceeding with full sweep." -} - -if [ "${DIAG}" = "1" ]; then - run_diag - exit 0 -fi - -# Always run a cheap probe first so full-sweep failures are caught early -# before we've spent minutes on the heavy --set full passes. -run_diag - -# ── Helper: run one ncu profile ────────────────────────────── -# tag : name used for the output file -# kernel : "fused" or "parallel" (drives Python driver dispatch) -# regex : ncu --kernel-name filter -# config : "A" or "B" -# mode : 15 or 16 -run_ncu() { - local tag="$1" - local kernel="$2" - local regex="$3" - local config="$4" - local mode="$5" - - local out="${ - - - - }/${tag}.ncu-rep" - - echo "" - echo ">>> ${tag}" - - # --launch-skip/--launch-count count ONLY kernels matching - # --kernel-name, so setup kernels (torch.randn, etc.) don't - # pollute the offsets. With --launch-skip=${WARMUP} and the - # Python driver doing ${WARMUP} warmup + ${ITERS} profiled - # calls, ncu captures exactly the profiled ones. - CUDA_VISIBLE_DEVICES="${GPU}" "${NCU}" \ - --force-overwrite \ - --target-processes all \ - --kernel-name "${regex}" \ - --launch-skip "${WARMUP}" \ - --launch-count "${ITERS}" \ - --set "${SECTION_SET}" \ - ${NCU_EXTRA_FLAGS} \ - --export "${out}" \ - python "${PY_DRIVER}" \ - --config "${config}" \ - --eff-bs "${EFF_BS}" \ - --mode "${mode}" \ - --power "${POWER}" \ - --num-splits "${NUM_SPLITS}" \ - --kernel "${kernel}" \ - --warmup "${WARMUP}" \ - --iters "${ITERS}" - - echo " report: ${out}" -} - -# ── Sweep ──────────────────────────────────────────────────── -for MODE in 15 16; do - if [ "${MODE}" -eq 15 ]; then MODE_TAG="SP2"; else MODE_TAG="SP3"; fi - for CONFIG in A B; do - run_ncu "fused_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}" \ - "fused" "${FUSED_REGEX}" "${CONFIG}" "${MODE}" - run_ncu "parallel_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}_ns${NUM_SPLITS}" \ - "parallel" "${PARALLEL_REGEX}" "${CONFIG}" "${MODE}" - done -done - -echo "" -echo "============================================================" -echo "All profiles done. Reports saved under:" -echo " ${OUT_DIR}" -echo "" -echo "Interactive analysis (recommended):" -echo " ncu-ui ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.ncu-rep" -echo "" -echo "CLI summary, one kernel at a time:" -echo " ncu --import ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.ncu-rep --page details" -echo "" -echo "Side-by-side diff (CLI):" -echo " ncu --import ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.ncu-rep \\" -echo " --import ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.ncu-rep \\" -echo " --page details --csv > ${OUT_DIR}/compare_SP2_cfgA.csv" -echo "" -echo "What to look at (to pinpoint the overhead vs fused):" -echo " * Section 'GPU Speed Of Light Throughput'" -echo " → SM %, Memory %, which one is the bound?" -echo " * Section 'Launch Statistics'" -echo " → Grid/Block size, Dynamic Shared Mem per block" -echo " * Section 'Occupancy'" -echo " → Theoretical vs achieved; limit (smem / regs / blocks/SM)" -echo " * Section 'Warp State Statistics'" -echo " → Stall breakdown: Stall Barrier (__syncthreads/__threadfence)," -echo " Stall Long Scoreboard (global memory), Stall Short Scoreboard" -echo " (smem/atomic)" -echo " * Section 'Memory Workload Analysis'" -echo " → L2/Device throughput, atomic traffic, smem bank conflicts" -echo " * Section 'Compute Workload Analysis'" -echo " → Pipe utilisation (FMA / ALU / FP64)" -echo "" -echo "Likely suspects for the parallel-vs-fused gap:" -echo " - Occupancy limited by the large dynamic smem (kSmem + chunk_bytes)" -echo " - Stall Barrier dominating due to the __threadfence before atomicInc" -echo " - Phase 1 CTAs repeat Stage-2 refinement that fused does only once" -echo " → visible as 'Pipe Utilisation ALU / Special' for integer radix ops" -echo "============================================================" diff --git a/examples/profile_parallel_vs_fused_nsys.sh b/examples/profile_parallel_vs_fused_nsys.sh deleted file mode 100755 index 3d64519a..00000000 --- a/examples/profile_parallel_vs_fused_nsys.sh +++ /dev/null @@ -1,211 +0,0 @@ -#!/usr/bin/env bash -# ============================================================ -# Nsight Systems (nsys) profiling — timeline view of the parallel -# vs fused TopK kernels. -# -# Why nsys and not ncu here: -# ncu needs SM-level perf counters (sm__*), which on this box are -# gated by the nvidia driver's RmProfilingAdminOnly flag — and we -# have no sudo. nsys uses CUPTI API/activity tracing and kernel -# timing, which do NOT require admin. That's enough to answer the -# "where does the 6-8us overhead come from" question, because we -# get per-kernel durations, gaps on the stream, memcpy/memset -# traffic, and NVTX range timing. -# -# Profiles both: -# - TopKOutput_Fused_Kernel (csrc/topk_sglang.cu) -# - TopKOutput_Parallel_Kernel (csrc/topk_sglang_parallel.cu) -# -# For each of mode 15 (SHIFT_POW2), mode 16 (SHIFT_POW3) and both -# configs A (topk=2048 pages=32K) and B (topk=30 pages=2K). -# -# Produces one .nsys-rep per (kernel × mode × config). Open with: -# nsys-ui .nsys-rep -# or dump CLI summaries with: -# nsys stats .nsys-rep -# -# Usage: -# bash examples/profile_parallel_vs_fused_nsys.sh # defaults -# GPU=7 NUM_SPLITS=2 bash examples/profile_parallel_vs_fused_nsys.sh -# ITERS=50 bash examples/profile_parallel_vs_fused_nsys.sh # more samples -# ============================================================ -set -euo pipefail -SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -PY_DRIVER="${SCRIPT_DIR}/../benchmarks/profile_parallel_vs_fused.py" - -# ── Defaults ────────────────────────────────────────────────── -GPU=${GPU:-7} -EFF_BS=${EFF_BS:-1} -NUM_SPLITS=${NUM_SPLITS:-2} -POWER=${POWER:--1.0} -WARMUP=${WARMUP:-20} -# For nsys we want *many* iterations so the per-kernel timing is -# statistically meaningful and the timeline is readable. -ITERS=${ITERS:-50} - -# Prefer the CUDA-13 toolchain's nsys (matches the torch CUDA ABI). -NSYS=${NSYS:-$(command -v nsys || echo /usr/local/cuda/bin/nsys)} -if [ ! -x "${NSYS}" ]; then - echo "ERROR: nsys not found. Tried: ${NSYS}" - echo " Set NSYS=/path/to/nsys manually." - exit 1 -fi - -TIMESTAMP=$(date +%Y%m%d_%H%M%S) -OUT_DIR="${SCRIPT_DIR}/results/nsys_parallel_vs_fused_${TIMESTAMP}" -mkdir -p "${OUT_DIR}" - -# nsys writes intermediate files under $TMPDIR/nvidia/nsight_systems. -# On shared systems /tmp/nvidia is often owned by another user who -# created it first, and we can't write there. Redirect to a -# user-writable cache dir. -export TMPDIR="${TMPDIR:-${HOME}/.cache/nsys_tmp}" -mkdir -p "${TMPDIR}" - -echo "============================================================" -echo "Nsight Systems profile: parallel vs fused TopK" -echo " GPU: ${GPU}" -echo " eff_bs: ${EFF_BS}" -echo " num_splits: ${NUM_SPLITS}" -echo " power (p): ${POWER}" -echo " warmup: ${WARMUP}" -echo " iters: ${ITERS} (profiled launches)" -echo " nsys binary: ${NSYS}" -echo " output dir: ${OUT_DIR}" -echo "============================================================" -"${NSYS}" --version 2>&1 | head -2 - -# ── Helper: run one nsys profile ───────────────────────────── -run_nsys() { - local tag="$1" - local kernel="$2" - local config="$3" - local mode="$4" - - local out="${OUT_DIR}/${tag}" - - echo "" - echo ">>> ${tag}" - - # --trace cuda,nvtx : CUDA API/runtime + NVTX ranges. NVTX - # stays on so the timeline still shows - # where the profiled region begins. - # --sample none / --cpuctxsw none: skip CPU callstack sampling and - # context-switch tracing — both admin-gated - # on this box and we don't need them. - # --cuda-memory-usage true: log cudaMalloc/cudaFree/cudaMemset so we - # can see if at::empty / at::zeros costs - # anything on the hot path. - # - # Capture-range flags intentionally OMITTED. On some nsys builds - # --capture-range=nvtx silently yields "No reports were generated" - # when the ranges don't line up exactly; profiling the whole run - # is more robust and the warmup is easy to filter out later - # (NVTX range "profile-*" tags the profiled region in nsys stats). - CUDA_VISIBLE_DEVICES="${GPU}" "${NSYS}" profile \ - --output "${out}" \ - --force-overwrite true \ - --trace cuda,nvtx \ - --sample none \ - --cpuctxsw none \ - --cuda-memory-usage true \ - python "${PY_DRIVER}" \ - --config "${config}" \ - --eff-bs "${EFF_BS}" \ - --mode "${mode}" \ - --power "${POWER}" \ - --num-splits "${NUM_SPLITS}" \ - --kernel "${kernel}" \ - --warmup "${WARMUP}" \ - --iters "${ITERS}" - - echo " report: ${out}.nsys-rep" -} - -# ── Sweep ──────────────────────────────────────────────────── -for MODE in 15 16; do - if [ "${MODE}" -eq 15 ]; then MODE_TAG="SP2"; else MODE_TAG="SP3"; fi - for CONFIG in A B; do - run_nsys "fused_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}" \ - "fused" "${CONFIG}" "${MODE}" - run_nsys "parallel_${MODE_TAG}_cfg${CONFIG}_eff${EFF_BS}_ns${NUM_SPLITS}" \ - "parallel" "${CONFIG}" "${MODE}" - done -done - -# ── Auto-dump CLI summaries for every report ───────────────── -# `nsys stats` produces text tables that are immediately readable -# and answer most "where did the time go" questions without needing -# the GUI. We dump the most useful ones for every report and stash -# them alongside. -echo "" -echo "============================================================" -echo "Dumping text summaries ('nsys stats') for every report..." -echo "============================================================" -for rep in "${OUT_DIR}"/*.nsys-rep; do - name="$(basename "${rep}" .nsys-rep)" - echo "" - echo ">>> summary for ${name}" - summary="${OUT_DIR}/${name}.summary.txt" - { - echo "### ${name}" - echo "" - echo "## cuda_api_sum: CUDA runtime API call distribution" - echo "## (count, avg, med, min, max of cudaLaunchKernel / cudaMalloc / etc.)" - "${NSYS}" stats --report cuda_api_sum --format table "${rep}" 2>&1 || true - echo "" - echo "## cuda_gpu_kern_sum: per-kernel GPU duration stats" - echo "## (mean/median/std/min/max duration per kernel name, with instance count)" - "${NSYS}" stats --report cuda_gpu_kern_sum --format table "${rep}" 2>&1 || true - echo "" - echo "## cuda_gpu_mem_size_sum: memcpy / memset by size" - echo "## (expect 0 memset entries for parallel — no at::zeros on the hot path)" - "${NSYS}" stats --report cuda_gpu_mem_size_sum --format table "${rep}" 2>&1 || true - echo "" - echo "## cuda_gpu_mem_time_sum: memcpy / memset by time" - "${NSYS}" stats --report cuda_gpu_mem_time_sum --format table "${rep}" 2>&1 || true - echo "" - echo "## cuda_kern_exec_sum: kernel launch→exec latency" - echo "## (host-side cudaLaunchKernel cost separated from GPU exec cost)" - "${NSYS}" stats --report cuda_kern_exec_sum --format table "${rep}" 2>&1 || true - echo "" - echo "## nvtx_pushpop_sum: NVTX ranges (the 'profile-*' wrapped region)" - "${NSYS}" stats --report nvtx_pushpop_sum --format table "${rep}" 2>&1 || true - } > "${summary}" 2>&1 - echo " saved: ${summary}" -done - -echo "" -echo "============================================================" -echo "Reports saved to: ${OUT_DIR}" -echo "" -echo "Quick read — compare fused vs parallel summaries side-by-side:" -echo "" -echo " diff -y --width=200 \\" -echo " ${OUT_DIR}/fused_SP2_cfgA_eff${EFF_BS}.summary.txt \\" -echo " ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.summary.txt \\" -echo " | less" -echo "" -echo "Interactive timeline (if you have X11/SSH forwarding):" -echo " nsys-ui ${OUT_DIR}/parallel_SP2_cfgA_eff${EFF_BS}_ns${NUM_SPLITS}.nsys-rep" -echo "" -echo "What to look for (to nail the overhead vs fused):" -echo " * 'cuda_gpu_kern_sum' mean duration for each kernel" -echo " → fused is one kernel × (WARMUP+ITERS), parallel is one kernel × (WARMUP+ITERS)" -echo " (single-kernel design). Mean duration difference = the GPU work" -echo " gap (Stage-1 savings minus merge cost)." -echo " * 'cuda_api_sum' cudaLaunchKernel / cudaMalloc / cudaFree counts" -echo " → if parallel shows more launches than fused, there's an unexpected" -echo " extra kernel. Also watch the time spent in cudaLaunchKernel." -echo " * 'cuda_gpu_mem_size_sum' cudaMemset entries" -echo " → should be zero for parallel now (__device__ counter removed" -echo " at::zeros). Any memset here IS overhead we need to explain." -echo " * 'cuda_kern_exec_sum'" -echo " → separates host-side cudaLaunchKernel latency from GPU kernel time." -echo " * 'nvtx_pushpop_sum' profile-* range duration / ${ITERS}" -echo " → wall-clock per-call including CPU-side overhead." -echo "" -echo "Timeline view (nsys-ui) additionally shows *gaps* between kernels" -echo "on the GPU stream — the cost of __threadfence + atomicInc barrier" -echo "shows up as a visible pause between Phase-1 work and the merge." -echo "============================================================" diff --git a/examples/remap_function_bench_topk_parallel.sh b/examples/remap_function_bench_topk_parallel.sh index 4a4e4c57..7be48b88 100755 --- a/examples/remap_function_bench_topk_parallel.sh +++ b/examples/remap_function_bench_topk_parallel.sh @@ -1,245 +1,156 @@ #!/usr/bin/env bash # ============================================================ -# Remap Function Benchmark — Parallel TopK variant. +# Three-way TopK kernel latency comparison for K=30. # -# Wraps bench_topk.py --remap-bench with --bench-parallel so the -# output table includes a "par_ms" column comparing the split+merge -# kernel (topk_output_sglang_parallel) against the single-CTA -# fused kernel. Also sweeps batch size and num_splits so the -# occupancy-vs-merge-overhead curve is visible. +# Compares (per (batch_size, pages)): +# topk.cu -> topk_output (CUB BlockRadixSort full sort) +# topk_sglang.cu -> topk_output_sglang + +# topk_output_sglang_fused (2-stage radix select) +# topk_sglang_merge.cu -> topk_output_adaptive (adaptive split SELECT32_SORT32) # -# Pipeline mirrors remap_function_bench_topk2028.sh: -# Step 1 — calibrate (can be skipped with --real-histograms) -# Step 2 — autotune per-mode hparams by fused-kernel latency -# Step 3 — remap bench, looped over NUM_SPLITS_SWEEP values +# Pages are varied by --seq-lens (with --page-size 1: pages == seq_len). +# Default sweep is the matrix the user requested: +# batch_sizes = {1, 2, 4, 8, 16} +# pages = {4096, 8192, 16384} +# topk = 30 # -# Usage: -# bash remap_function_bench_topk_parallel.sh --gpu 4 +# No calibration, no remap autotune, no model download — purely synthetic +# scores so the only variable is the kernel itself. # -# # Explicit batch-size sweep: -# bash remap_function_bench_topk_parallel.sh --gpu 4 \ -# --batch-sizes "1 2 4 8" --num-splits-sweep "auto 2 4 8" +# Usage: +# bash examples/remap_function_bench_topk_parallel.sh --gpu 0 +# bash examples/remap_function_bench_topk_parallel.sh --gpu 0 \ +# --batch-sizes "1 2 4 8 16" \ +# --seq-lens "4096 8192 16384 32768" # ============================================================ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" BENCH_DIR="${SCRIPT_DIR}/../benchmarks" -# ── Defaults ────────────────────────────────────────────────── -GPU_ID=7 -MODEL_NAME="Qwen/Qwen3-1.7B" -TOPK_VAL=2048 -MEM=0.7 -MAX_TOTAL_TOKENS=64768 -MIN_FREE_DISK_GB=22 -ALGO="block_sparse_attention" -SAMPLE_STRIDE=1 -SEQ_LEN=32768 -BLOCK_SIZE=1 +# ── Defaults (matrix the brief calls out) ───────────────────── +GPU_ID=0 +TOPK_VALS="30" BATCH_SIZES="1 2 4 8 16" +SEQ_LENS="4096 8192 16384" # pages-per-seg when page-size=1 NUM_KV_HEADS=8 -DISTRIBUTIONS="normal bucket_uniform" -# Modes excluding 1 (LUT_CDF) and 2 (Quantile) which are discarded. -MAPPING_MODES="0 3 6 7 9 10 11 13 15 16 17 18 19" -MAPPING_HPARAM=0.5 -REPEAT=100 +PAGE_SIZE=1 +RESERVED_BOS=1 +RESERVED_EOS=2 +DISTRIBUTIONS="normal" WARMUP=20 -# "auto" lets bench_topk.py pick via sqrt(pages/topk). Explicit ints -# pin a split count for A/B comparisons. -NUM_SPLITS_SWEEP="auto 2 4 8" -REAL_HISTOGRAMS="/var/tmp/zhuominc/vortex_torch/calibration/raw_histograms_qwen3-1.7B.npy" -SKIP_AUTOTUNE=0 -PINNED_AUTOTUNE_JSON="" +REPEAT=200 -# ── Parse arguments ─────────────────────────────────────────── +# ── Arg parsing ─────────────────────────────────────────────── while [[ $# -gt 0 ]]; do case "$1" in - --model-name) MODEL_NAME="$2"; shift 2 ;; - --topk-val) TOPK_VAL="$2"; shift 2 ;; - --mem) MEM="$2"; shift 2 ;; - --max-total-tokens) MAX_TOTAL_TOKENS="$2"; shift 2 ;; - --min-free-disk-gb) MIN_FREE_DISK_GB="$2"; shift 2 ;; - --gpu) GPU_ID="$2"; shift 2 ;; - --algo) ALGO="$2"; shift 2 ;; - --real-histograms) REAL_HISTOGRAMS="$2"; shift 2 ;; - --sample-stride) SAMPLE_STRIDE="$2"; shift 2 ;; - --seq-len) SEQ_LEN="$2"; shift 2 ;; - --block-size|--page-size) BLOCK_SIZE="$2"; shift 2 ;; - --batch-sizes) BATCH_SIZES="$2"; shift 2 ;; - --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; - --distributions) DISTRIBUTIONS="$2"; shift 2 ;; - --modes) MAPPING_MODES="$2"; shift 2 ;; - --mapping-hparam) MAPPING_HPARAM="$2"; shift 2 ;; - --repeat) REPEAT="$2"; shift 2 ;; - --warmup) WARMUP="$2"; shift 2 ;; - --num-splits-sweep) NUM_SPLITS_SWEEP="$2"; shift 2 ;; - --skip-autotune) SKIP_AUTOTUNE=1; shift 1 ;; - --pinned-autotune-json) PINNED_AUTOTUNE_JSON="$2"; SKIP_AUTOTUNE=1; shift 2 ;; - *) echo "Unknown option: $1"; exit 1 ;; + --gpu) GPU_ID="$2"; shift 2 ;; + --topk-vals) TOPK_VALS="$2"; shift 2 ;; + --batch-sizes) BATCH_SIZES="$2"; shift 2 ;; + --seq-lens) SEQ_LENS="$2"; shift 2 ;; + --page-size) PAGE_SIZE="$2"; shift 2 ;; + --num-kv-heads) NUM_KV_HEADS="$2"; shift 2 ;; + --distributions) DISTRIBUTIONS="$2"; shift 2 ;; + --reserved-bos) RESERVED_BOS="$2"; shift 2 ;; + --reserved-eos) RESERVED_EOS="$2"; shift 2 ;; + --warmup) WARMUP="$2"; shift 2 ;; + --repeat) REPEAT="$2"; shift 2 ;; + *) echo "Unknown option: $1" >&2; exit 1 ;; esac done export CUDA_VISIBLE_DEVICES="${GPU_ID}" export PYTORCH_CUDA_ALLOC_CONF="${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True}" -export SGL_ENABLE_JIT_DEEPGEMM="${SGL_ENABLE_JIT_DEEPGEMM:-true}" - -if [ -z "${DG_JIT_NVCC_COMPILER:-}" ]; then - if [ -x /usr/local/cuda/bin/nvcc ]; then - export CUDA_HOME="${CUDA_HOME:-/usr/local/cuda}" - export PATH="${CUDA_HOME}/bin:${PATH}" - export DG_JIT_NVCC_COMPILER="${CUDA_HOME}/bin/nvcc" - elif command -v nvcc >/dev/null 2>&1; then - export DG_JIT_NVCC_COMPILER="$(command -v nvcc)" - fi -fi - -MIN_SEQ_LEN=$(( (TOPK_VAL + 4) * BLOCK_SIZE )) -if [ "${SEQ_LEN}" -lt "${MIN_SEQ_LEN}" ]; then - echo "ERROR: --seq-len ${SEQ_LEN} too small for --topk-val ${TOPK_VAL} @ --block-size ${BLOCK_SIZE}." - echo " Minimum: ${MIN_SEQ_LEN} (pages/seg must exceed topk_val + 3 reserved pages)" - exit 1 -fi RESULTS_DIR="${SCRIPT_DIR}/results" mkdir -p "${RESULTS_DIR}" TIMESTAMP=$(date +%Y%m%d_%H%M%S) -MODEL_SLUG="$(echo "${MODEL_NAME}" | tr '/' '_')" -RUN_DIR="${RESULTS_DIR}/parallel_bench_${MODEL_SLUG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" +TOPK_TAG="$(echo ${TOPK_VALS} | tr ' ' '-')" +RUN_DIR="${RESULTS_DIR}/three_way_topk${TOPK_TAG}_bs${PAGE_SIZE}_${TIMESTAMP}" mkdir -p "${RUN_DIR}" - -CALIBRATION_BASE="/var/tmp/zhuominc/vortex_torch/calibration" -MODEL_TAG="$(echo "${MODEL_NAME##*/}" | sed 's/^Q/q/')" -DEFAULT_REAL_HIST="${CALIBRATION_BASE}/raw_histograms_${MODEL_TAG}.npy" -mkdir -p "${CALIBRATION_BASE}" - -if [ -z "${REAL_HISTOGRAMS}" ] && [ -f "${DEFAULT_REAL_HIST}" ]; then - REAL_HISTOGRAMS="${DEFAULT_REAL_HIST}" -fi +JSON_PATH="${RUN_DIR}/three_way.json" +CSV_PATH="${RUN_DIR}/summary.csv" echo "============================================================" -echo "Remap Function Benchmark (Parallel TopK variant)" -echo " Model: ${MODEL_NAME}" -echo " Algorithm: ${ALGO}" -echo " TopK: ${TOPK_VAL}" -echo " Block size: ${BLOCK_SIZE}" -echo " Seq len: ${SEQ_LEN} ($(( SEQ_LEN / BLOCK_SIZE )) pages/seg)" -echo " Batch sizes: ${BATCH_SIZES}" -echo " KV heads: ${NUM_KV_HEADS}" -echo " Distributions: ${DISTRIBUTIONS}" -echo " Mapping modes: ${MAPPING_MODES}" -echo " num_splits sweep:${NUM_SPLITS_SWEEP}" -echo " GPU: ${GPU_ID}" -echo " Real histograms: ${REAL_HISTOGRAMS:-}" -echo " Output: ${RUN_DIR}" +echo "Three-way TopK kernel comparison" +echo " TopK sweep: ${TOPK_VALS}" +echo " Batch sizes: ${BATCH_SIZES}" +echo " Seq lengths: ${SEQ_LENS} (page_size=${PAGE_SIZE})" +echo " KV heads: ${NUM_KV_HEADS}" +echo " Distributions: ${DISTRIBUTIONS}" +echo " GPU: ${GPU_ID}" +echo " Warmup/repeat: ${WARMUP}/${REPEAT}" +echo " Output dir: ${RUN_DIR}" echo "============================================================" -# ── Step 1: Calibrate ──────────────────────────────────────── -if [ -n "${REAL_HISTOGRAMS}" ]; then - echo "" - echo ">>> Step 1: SKIPPED (using provided --real-histograms ${REAL_HISTOGRAMS})" - REAL_HIST_PATH="${REAL_HISTOGRAMS}" -else - echo "" - echo ">>> Step 1: Calibrating ${MODEL_NAME} — collecting real topk histograms" - CALIBRATION_DIR="${CALIBRATION_BASE}/staging_${MODEL_TAG}_topk${TOPK_VAL}_bs${BLOCK_SIZE}_${TIMESTAMP}" - mkdir -p "${CALIBRATION_DIR}" - python "${BENCH_DIR}/calibrate_topk.py" \ - --model-name "${MODEL_NAME}" \ - --topk-val "${TOPK_VAL}" \ - --page-size "${BLOCK_SIZE}" \ - --mem "${MEM}" \ - --max-total-tokens "${MAX_TOTAL_TOKENS}" \ - --min-free-disk-gb "${MIN_FREE_DISK_GB}" \ - --vortex-module-name "${ALGO}" \ - --output-dir "${CALIBRATION_DIR}" \ - 2>&1 | tee "${RUN_DIR}/step1_calibrate.log" - mv -f "${CALIBRATION_DIR}/raw_histograms.npy" "${DEFAULT_REAL_HIST}" - REAL_HIST_PATH="${DEFAULT_REAL_HIST}" - echo ">>> Step 1: Done. raw_histograms -> ${REAL_HIST_PATH}" -fi +# ── Run bench_topk.py with all (B, seq_len, K) combos in one shot ── +# --mapping-modes 0 = MAPPING_NONE → no remap, no autotune needed. +# --remap-bench = drives the per-config table that includes baseline +# (topk_sglang) + naive (topk.cu) + sglang_ori rows. +# --bench-parallel = adds the topk_sglang_merge adaptive measurement +# into each row (under "parallel_ms"). +PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ + --remap-bench \ + --bench-parallel \ + --batch-sizes ${BATCH_SIZES} \ + --num-kv-heads "${NUM_KV_HEADS}" \ + --seq-lens ${SEQ_LENS} \ + --topk-vals ${TOPK_VALS} \ + --page-size "${PAGE_SIZE}" \ + --reserved-bos "${RESERVED_BOS}" \ + --reserved-eos "${RESERVED_EOS}" \ + --distributions ${DISTRIBUTIONS} \ + --mapping-modes 0 \ + --warmup "${WARMUP}" \ + --repeat "${REPEAT}" \ + --output-json "${JSON_PATH}" \ + 2>&1 | tee "${RUN_DIR}/bench_topk.log" -# ── Step 2: Autotune ───────────────────────────────────────── -AUTOTUNE_JSON="${RUN_DIR}/autotune_results.json" -if [ "${SKIP_AUTOTUNE}" -eq 1 ]; then - echo "" - if [ -n "${PINNED_AUTOTUNE_JSON}" ]; then - echo ">>> Step 2: SKIPPED (pinned hparams from ${PINNED_AUTOTUNE_JSON})" - AUTOTUNE_ARGS="--autotune-json ${PINNED_AUTOTUNE_JSON}" - else - echo ">>> Step 2: SKIPPED (using fallback --mapping-hparam ${MAPPING_HPARAM})" - AUTOTUNE_ARGS="" - fi -else - echo "" - echo ">>> Step 2: Auto-tuning hyperparameters by profiled topk kernel latency" - # Autotune on the largest batch size so the picked hparam matches realistic - # decode conditions; the hparam itself is largely batch-invariant. - FIRST_BS="$(echo ${BATCH_SIZES} | awk '{print $NF}')" - PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/autotune_topk_mapping.py" \ - --batch-size "${FIRST_BS}" \ - --num-kv-heads "${NUM_KV_HEADS}" \ - --seq-len "${SEQ_LEN}" \ - --topk-val "${TOPK_VAL}" \ - --page-size "${BLOCK_SIZE}" \ - --real-histograms "${REAL_HIST_PATH}" \ - --warmup "${WARMUP}" \ - --repeat "${REPEAT}" \ - --collect-stats \ - --output-json "${AUTOTUNE_JSON}" \ - 2>&1 | tee "${RUN_DIR}/step2_autotune.log" - echo ">>> Step 2: Done. Autotune results saved to ${AUTOTUNE_JSON}" - AUTOTUNE_ARGS="--autotune-json ${AUTOTUNE_JSON}" -fi +# ── Aggregate to a clean CSV: one row per (B, pages, K, dist) ───── +python - "${JSON_PATH}" "${CSV_PATH}" <<'PY' +import csv, json, sys +src, dst = sys.argv[1], sys.argv[2] +with open(src) as f: + data = json.load(f) +rows = data["results"] if isinstance(data, dict) and "results" in data else data +# Header. Latencies in microseconds. +hdr = [ + "topk", "batch_size", "pages", "distribution", + "cub_topk_us", # topk.cu / topk_output (None when pages > 8192) + "sglang_baseline_us", # topk_sglang.cu / topk_output_sglang + "sglang_fused_us", # topk_sglang.cu / topk_output_sglang_fused (==baseline @ MAPPING_NONE) + "adaptive_us", # topk_sglang_merge.cu / topk_output_adaptive + "speedup_adaptive_vs_fused", + "speedup_adaptive_vs_cub", +] +with open(dst, "w", newline="") as f: + w = csv.writer(f) + w.writerow(hdr) + for r in rows: + B = r["batch_size"]; pg = r["pages_per_seg"]; K = r["topk_val"]; dist = r["distribution"] + cub = r.get("naive_ms") + baseline = r.get("baseline_ms") + none_mode = next((m for m in r["modes"] if m.get("mode_name") == "None"), None) + adaptive = none_mode.get("parallel_ms") if none_mode else None + # At MAPPING_NONE the fused kernel == baseline kernel (no remap branch), + # so report baseline as the fused number too for clarity. + fused = baseline + def us(x): return f"{x*1000:.3f}" if x is not None else "" + sp_f = f"{baseline/adaptive:.3f}" if (adaptive and baseline) else "" + sp_c = f"{cub/adaptive:.3f}" if (adaptive and cub) else "" + w.writerow([K, B, pg, dist, us(cub), us(baseline), us(fused), us(adaptive), sp_f, sp_c]) +print(f"wrote {dst}") +PY -# ── Step 3: Remap + Parallel bench, sweeping num_splits ────── -echo "" -echo ">>> Step 3: Timing baseline / fused / parallel with num_splits sweep" - -for NS in ${NUM_SPLITS_SWEEP}; do - if [ "${NS}" = "auto" ]; then - NS_ARG="--num-splits -1" - NS_TAG="auto" - else - NS_ARG="--num-splits ${NS}" - NS_TAG="ns${NS}" - fi - REMAP_JSON="${RUN_DIR}/remap_bench_${NS_TAG}.json" - LOG="${RUN_DIR}/step3_remap_bench_${NS_TAG}.log" - BENCH_EXTRA=() - [ -n "${REAL_HIST_PATH}" ] && BENCH_EXTRA+=(--real-histograms "${REAL_HIST_PATH}") - echo "" - echo "--- num_splits=${NS_TAG} ---" - PYTHONPATH="${SCRIPT_DIR}/.." python "${BENCH_DIR}/bench_topk.py" \ - --remap-bench \ - --bench-parallel \ - ${NS_ARG} \ - --batch-sizes ${BATCH_SIZES} \ - --num-kv-heads "${NUM_KV_HEADS}" \ - --seq-lens "${SEQ_LEN}" \ - --topk-vals "${TOPK_VAL}" \ - --page-size "${BLOCK_SIZE}" \ - --distributions ${DISTRIBUTIONS} \ - --mapping-modes ${MAPPING_MODES} \ - --mapping-hparam "${MAPPING_HPARAM}" \ - ${AUTOTUNE_ARGS} \ - "${BENCH_EXTRA[@]}" \ - --warmup "${WARMUP}" \ - --repeat "${REPEAT}" \ - --output-json "${REMAP_JSON}" \ - 2>&1 | tee "${LOG}" - echo ">>> num_splits=${NS_TAG}: JSON -> ${REMAP_JSON}" -done - -# ── Summary ─────────────────────────────────────────────────── +# ── Print human-readable summary table ────────────────────────── echo "" echo "============================================================" -echo "Parallel TopK Benchmark Complete" -echo " Model: ${MODEL_NAME}" -echo " Block size: ${BLOCK_SIZE}" -echo " Batch sizes: ${BATCH_SIZES}" -echo " num_splits sweep: ${NUM_SPLITS_SWEEP}" -echo " All outputs in: ${RUN_DIR}/" -echo " autotune_results.json — latency-ranked mapping hparams" -echo " remap_bench_.json — per-config latencies including par_ms" -echo " step{1,2,3}_*.log — pipeline logs" +echo "Summary (us per kernel call; speedup = fused_us / adaptive_us)" echo "============================================================" +column -t -s, "${CSV_PATH}" || cat "${CSV_PATH}" + +echo "" +echo "Done. Results:" +echo " raw JSON: ${JSON_PATH}" +echo " summary: ${CSV_PATH}" +echo " log: ${RUN_DIR}/bench_topk.log" diff --git a/examples/test_topk.py b/examples/test_topk.py deleted file mode 100644 index 01edc7b4..00000000 --- a/examples/test_topk.py +++ /dev/null @@ -1,118 +0,0 @@ -import torch -import triton -# topk_output_sglang expects sparse_kv_indptr before dense_kv_indices (unlike topk_output). -from vortex_torch_C import topk_output_sglang as topk_output - -SEQ_LENS = [4096] -BATCH_SIZES = [256] - -K = 32 -RESERVE_BOS = 0 -RESERVE_EOS = 0 -DEVICE = "cuda" - - -def make_inputs(batch_size, seq_len, k, reserve_bos, reserve_eos, device="cuda"): - dense_kv_indptr = torch.arange( - 0, batch_size * seq_len + 1, seq_len, dtype=torch.int32, device=device - ) - - dense_kv_indices = torch.arange( - 0, batch_size * seq_len, dtype=torch.int32, device=device - ) - - scores = torch.randn( - batch_size * seq_len, dtype=torch.bfloat16, device=device - ) - - # ✅ Fixed CSR-style sparse indptr - sparse_kv_indptr = torch.arange( - 0, batch_size * k + 1, k, dtype=torch.int32, device=device - ) - - sparse_kv_indices = torch.empty( - batch_size * k, dtype=torch.int32, device=device - ) - - return ( - scores, - dense_kv_indptr, - dense_kv_indices, - sparse_kv_indptr, - sparse_kv_indices, - ) - - -def bench_one(batch_size, seq_len, k, reserve_bos, reserve_eos): - ( - scores, - dense_kv_indptr, - dense_kv_indices, - sparse_kv_indptr, - sparse_kv_indices, - ) = make_inputs( - batch_size=batch_size, - seq_len=seq_len, - k=k, - reserve_bos=reserve_bos, - reserve_eos=reserve_eos, - device=DEVICE, - ) - - def fn(): - topk_output( - scores, - dense_kv_indptr, - sparse_kv_indptr, - dense_kv_indices, - sparse_kv_indices, - batch_size, - k, - reserve_bos, - reserve_eos, - seq_len, - ) - - # warmup - for _ in range(10): - fn() - torch.cuda.synchronize() - - ms = triton.testing.do_bench( - fn, - warmup=100, - rep=1000, - return_mode="mean", - ) - return ms - - -def main(): - torch.cuda.init() - - results = {} - - for bs in BATCH_SIZES: - results[bs] = {} - for seq_len in SEQ_LENS: - ms = bench_one( - batch_size=bs, - seq_len=seq_len, - k=K, - reserve_bos=RESERVE_BOS, - reserve_eos=RESERVE_EOS, - ) - results[bs][seq_len] = ms - print(f"bs={bs:>3}, seq_len={seq_len:>4} -> {ms:.6f} ms") - - print("\nLatency table (ms):") - header = "bs\\seq".ljust(10) + "".join(f"{s:>12}" for s in SEQ_LENS) - print(header) - - for bs in BATCH_SIZES: - row = f"{bs:<10}" + "".join(f"{results[bs][s]:>12.4f}" for s in SEQ_LENS) - print(row) - - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/setup.py b/setup.py index 9ff56088..e886eaca 100644 --- a/setup.py +++ b/setup.py @@ -20,8 +20,8 @@ 'csrc/topk_sglang.cu', 'csrc/topk_sglang_profile.cu', 'csrc/topk_sglang_ori.cu', - 'csrc/topk_sglang_parallel.cu', - 'csrc/topk_sglang_cluster.cu', + 'csrc/topk_sglang_merge.cu', + 'csrc/topk_adaptive_profile.cu', ], include_dirs=['csrc'], extra_compile_args={