From dcd8063b20550ca3346475b327e45b41532a929c Mon Sep 17 00:00:00 2001 From: wzj Date: Wed, 18 Mar 2026 13:32:09 +0000 Subject: [PATCH 1/4] demo code --- .../att/decode_att/int8kv/normal/__init__.py | 0 .../normal/int8kv_flash_decodin_stage2.py | 306 ++++++++++++++++ .../int8kv/normal/int8kv_flash_decoding.py | 96 +++++ .../normal/int8kv_flash_decoding_stage1.py | 338 ++++++++++++++++++ .../normal/int8kv_flash_decoding_stage3.py | 95 +++++ .../att/decode_att/int8kv/normal/readme.txt | 1 + 6 files changed, 836 insertions(+) create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py create mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py new file mode 100644 index 0000000000..f5c0b9c395 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py @@ -0,0 +1,306 @@ +import torch +import triton +import triton.language as tl +from typing import Optional + +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Dict +from lightllm.common.triton_utils.autotuner import autotune, Autotuner + + +class GQADiverseDecodeStage2KernelConfig(KernelConfigs): + kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage2:v1" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + batch_size: int, + avg_seq_len_in_batch: int, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + ) -> dict: + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + batch_size_config: dict = finded_config[ + min( + finded_config.keys(), + key=lambda x: abs(int(x) - avg_seq_len_in_batch), + ) + ] + config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] + + return config + else: + config = { + "BLOCK_N": 16, + "num_warps": 2, + "num_stages": 2, + } + return config + + @classmethod + def save_config( + cls, + gqa_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + config_json: Dict[int, Dict[int, Dict]], + ): + key_params = { + "gqa_group_size": gqa_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def _fwd_kernel_flash_decode_diverse_stage2( 
+ Q, + stride_qbs, + stride_qh, + stride_qd, + K, + K_scale, + stride_kbs, + stride_kh, + stride_kd, + V, + V_scale, + stride_vbs, + stride_vh, + stride_vd, + sm_scale, + Req_to_tokens, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + B_req_idx, + B_Seqlen, + b_shared_seq_len, + Mid_O, # [batch, head, seq_block_num, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_kv_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, gqa_group_size) + + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) + cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_shared_len + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + store_seq_block = seq_start_block + tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + + block_n_size = tl.cdiv( + tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), + BLOCK_N, + ) + + if block_n_size == 0: + return + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + q = tl.load(Q + off_q) + + sum_exp = tl.zeros([gqa_group_size], dtype=tl.float32) + max_logic = tl.zeros([gqa_group_size], dtype=tl.float32) - float("inf") + acc = tl.zeros([gqa_group_size, BLOCK_HEADDIM], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) + k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) + k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) + # q (4, 128) k (128, BLOCK_N) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + v = tl.load( + V + off_v, + mask=n_mask[:, None], + other=0, + ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) + v_scale = tl.load( + V_scale + off_k_scale, + mask=n_mask[None, :], + other=0.0, + ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) + v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - 
new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + off_mid_o = ( + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + store_seq_block * stride_mid_os + + offs_d[None, :] + ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + store_seq_block + tl.store( + Mid_O + off_mid_o, + (acc / sum_exp[:, None]), + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + (max_logic + tl.log(sum_exp)), + ) + return + + +@torch.no_grad() +def flash_decode_stage2( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + Req_to_tokens: torch.Tensor, + B_req_idx: torch.Tensor, + B_Seqlen: torch.Tensor, + b_shared_seq_len: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, + mid_out_logsumexp: torch.Tensor, + block_seq: int, + run_config: Optional[dict] = None, +): + if not run_config: + run_config = GQADiverseDecodeStage2KernelConfig.try_to_get_best_config( + batch_size=int(q.shape[0]), + avg_seq_len_in_batch=max_len_in_batch, + gqa_group_size=int(q.shape[1] // k.shape[1]), + q_head_dim=int(q.shape[2]), + block_seq=block_seq, + out_dtype=q.dtype, + ) + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + BLOCK_SEQ = block_seq + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + assert triton.next_power_of_2(Lk) == Lk + KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] + assert KV_QUANT_GROUP_SIZE == 8 + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + + assert k.stride() == v.stride() + + _fwd_kernel_flash_decode_diverse_stage2[grid]( + Q=q, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + K=k, + K_scale=k_scale, + stride_kbs=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + V=v, + V_scale=v_scale, + stride_vbs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + sm_scale=sm_scale, + Req_to_tokens=Req_to_tokens, + stride_req_to_tokens_b=Req_to_tokens.stride(0), + stride_req_to_tokens_s=Req_to_tokens.stride(1), + B_req_idx=B_req_idx, + B_Seqlen=B_Seqlen, + b_shared_seq_len=b_shared_seq_len, + Mid_O=mid_out, + stride_mid_ob=mid_out.stride(0), + stride_mid_oh=mid_out.stride(1), + stride_mid_os=mid_out.stride(2), + stride_mid_od=mid_out.stride(3), + Mid_O_LogExpSum=mid_out_logsumexp, # [batch, head, seq_block_num] + stride_mid_o_eb=mid_out_logsumexp.stride(0), + stride_mid_o_eh=mid_out_logsumexp.stride(1), + stride_mid_o_es=mid_out_logsumexp.stride(2), + gqa_group_size=gqa_group_size, + BLOCK_SEQ=block_seq, + BLOCK_HEADDIM=Lk, + BLOCK_N=BLOCK_N, + KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, + num_warps=num_warps, + num_stages=num_stages, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py 
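The orchestration file introduced below runs stage 1 over each group's shared prefix and stage 2 over every request's private suffix on separate CUDA streams, then merges the partial results in stage 3. As a minimal sketch of the mid-buffer bookkeeping (the helper below is illustrative only, assuming BLOCK_SEQ = 256 as in the code): stage 1 fills mid_o slots [0, ceil(shared_len / BLOCK_SEQ)), stage 2 continues from that slot for the suffix, and the per-request slot count never exceeds the max_kv_seq_len // BLOCK_SEQ + 2 blocks that mid_o is allocated with.

import math

def mid_block_slots(shared_len: int, seq_len: int, block_seq: int = 256):
    # slots written by stage 1 (shared prefix) and stage 2 (private suffix)
    stage1_blocks = math.ceil(shared_len / block_seq)
    stage2_blocks = math.ceil((seq_len - shared_len) / block_seq)
    return stage1_blocks, stage2_blocks

assert sum(mid_block_slots(300, 700)) <= 700 // 256 + 2  # (2, 2) -> 4 of 4 slots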
new file mode 100644 index 0000000000..ad6a8b5b3a --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py @@ -0,0 +1,96 @@ +# 为 diverse mode 定制设计的 int8kv flash decoding attention 实现,可以实现更高效的多样性采样 +import torch +from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +from lightllm.common.basemodel.infer_struct import InferStateInfo +from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 +from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 +from .int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 +from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size + + +def token_decode_attention_flash_decoding( + q, + infer_state: InferStateInfo, + cache_k, + cache_k_scale, + cache_v, + cache_v_scale, + out=None, + alloc_tensor_func=torch.empty, + shared_streams_dict={}, +): + if "stream1" not in shared_streams_dict: + shared_streams_dict["stream1"] = torch.cuda.Stream() + if "stream2" not in shared_streams_dict: + shared_streams_dict["stream2"] = torch.cuda.Stream() + + stream1 = shared_streams_dict["stream1"] + stream2 = shared_streams_dict["stream2"] + + q_head_num = q.shape[1] + head_dim = q.shape[2] + + BLOCK_SEQ = 256 + batch_size = infer_state.batch_size + max_kv_seq_len = infer_state.max_kv_seq_len + calcu_shape1 = (batch_size, q_head_num, head_dim) + + o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out + + mid_o = alloc_tensor_func( + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=torch.float32, device="cuda" + ) + mid_o_logexpsum = alloc_tensor_func( + [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device="cuda" + ) + + current_stream = torch.cuda.current_stream() + + stream1.wait_stream(current_stream) + with torch.cuda.stream(stream1): + flash_decode_stage1( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + b_shared_seq_len=infer_state.b_shared_seq_len, + b_mark_shared_group=infer_state.b_mark_shared_group, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, + max_batch_group_size=get_diverse_max_batch_shared_group_size(), + ) + stream2.wait_stream(current_stream) + with torch.cuda.stream(stream2): + flash_decode_stage2( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_Seqlen=infer_state.b_seq_len, + b_shared_seq_len=infer_state.b_shared_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, + ) + + current_stream.wait_stream(stream1) + current_stream.wait_stream(stream2) + + flash_diverse_decode_stage3( + mid_out=mid_o, + mid_out_logexpsum=mid_o_logexpsum, + B_Seqlen=infer_state.b_seq_len, + b_shared_seq_len=infer_state.b_shared_seq_len, + O=o_tensor.view(calcu_shape1), + block_seq=BLOCK_SEQ, + ) + return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py new file mode 100644 index 0000000000..8040f657c6 --- /dev/null +++ 
b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py @@ -0,0 +1,338 @@ +import torch +import triton +import triton.language as tl +from typing import Optional +from lightllm.common.kernel_config import KernelConfigs +from frozendict import frozendict +from functools import lru_cache +from typing import Dict + + +class GQANormalDecodeStage1KernelConfig(KernelConfigs): + kernel_name: str = "_fwd_kernel_flash_decode_normal_stage1:v2" + + @classmethod + @lru_cache(maxsize=200) + def try_to_get_best_config( + cls, + batch_size: int, + avg_seq_len_in_batch: int, + gqa_group_size: int, + max_batch_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + ) -> dict: + key_params = { + "gqa_group_size": gqa_group_size, + "max_batch_group_size": max_batch_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + finded_config = cls.get_the_config(key_params) + + if finded_config: + batch_size_config: dict = finded_config[ + min( + finded_config.keys(), + key=lambda x: abs(int(x) - avg_seq_len_in_batch), + ) + ] + config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] + + return config + else: + config = { + "BLOCK_N": 16, + "num_warps": 2, + "num_stages": 2, + } + return config + + @classmethod + def save_config( + cls, + gqa_group_size: int, + max_batch_group_size: int, + q_head_dim: int, + block_seq: int, + out_dtype: str, + config_json: Dict[int, Dict[int, Dict]], + ): + key_params = { + "gqa_group_size": gqa_group_size, + "max_batch_group_size": max_batch_group_size, + "q_head_dim": q_head_dim, + "block_seq": block_seq, + "out_dtype": str(out_dtype), + } + key_params = frozendict(key_params) + + return cls.store_config(key_params, config_json) + + +@triton.jit +def _fwd_kernel_flash_decode_normal_stage1( + Q, + stride_qbs, + stride_qh, + stride_qd, + K, + K_scale, + stride_kbs, + stride_kh, + stride_kd, + V, + V_scale, + stride_vbs, + stride_vh, + stride_vd, + sm_scale, + Req_to_tokens, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + B_req_idx, + b_shared_seq_len, + b_mark_shared_group, + Mid_O, # [batch, head, seq_block_num, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_BATCH: tl.constexpr, + KV_QUANT_GROUP_SIZE: tl.constexpr, + NUM_GROUPS: tl.constexpr, +): + cur_batch = tl.program_id(0) + shared_batch_group_size = tl.load(b_mark_shared_group + cur_batch) + if shared_batch_group_size == 0: + return + cur_batch_end = cur_batch + 1 + cur_batch = cur_batch - (shared_batch_group_size - 1) + cur_kv_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + + cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, BLOCK_HEAD) + q_head_end_index = (cur_kv_head + 1) * gqa_group_size + cur_q_head_range = tl.where(cur_q_head_range < q_head_end_index, cur_q_head_range, cur_kv_head * gqa_group_size) + + offs_d = tl.arange(0, BLOCK_HEADDIM) + offs_d_scale = tl.arange(0, NUM_GROUPS) + cur_batch_seq_len = tl.load(b_shared_seq_len + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + 
BLOCK_SEQ) + + offs_batch = cur_batch + tl.arange(0, BLOCK_BATCH) + offs_batch = tl.where(offs_batch < cur_batch_end, offs_batch, cur_batch) + + off_q = offs_batch[:, None, None] * stride_qbs + cur_q_head_range[None, :, None] * stride_qh + offs_d[None, None, :] + + block_n_size = tl.cdiv( + tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), + BLOCK_N, + ) + + if block_n_size == 0: + return + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + Q_BATCH_HEAD_NUM: tl.constexpr = BLOCK_BATCH * BLOCK_HEAD + q = tl.load(Q + off_q).reshape(Q_BATCH_HEAD_NUM, BLOCK_HEADDIM) + + sum_exp = tl.zeros([Q_BATCH_HEAD_NUM], dtype=tl.float32) + max_logic = tl.zeros([Q_BATCH_HEAD_NUM], dtype=tl.float32) - float("inf") + acc = tl.zeros([Q_BATCH_HEAD_NUM, BLOCK_HEADDIM], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + # (128, 16) + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_scale = off_k // KV_QUANT_GROUP_SIZE + # (16, 16) + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) + k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) + k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + v = tl.load( + V + off_v, + mask=n_mask[:, None], + other=0, + ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) + v_scale = tl.load( + V_scale + off_k_scale, + mask=n_mask[None, :], + other=0.0, + ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) + v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) + + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic + + off_mid_o = ( + offs_batch[:, None, None] * stride_mid_ob + + cur_q_head_range[None, :, None] * stride_mid_oh + + seq_start_block * stride_mid_os + + offs_d[None, None, :] + ) + off_mid_o_logexpsum = ( + offs_batch[:, None] * stride_mid_o_eb + cur_q_head_range[None, :] * stride_mid_o_eh + seq_start_block + ) + tl.store( + Mid_O + off_mid_o, + (acc / sum_exp[:, None]).reshape(BLOCK_BATCH, BLOCK_HEAD, BLOCK_HEADDIM), + ) + tl.store( + Mid_O_LogExpSum + off_mid_o_logexpsum, + (max_logic + tl.log(sum_exp)).reshape(BLOCK_BATCH, BLOCK_HEAD), + ) + return + + +@torch.no_grad() +def flash_decode_stage1( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + v: torch.Tensor, + v_scale: torch.Tensor, + Req_to_tokens: torch.Tensor, + B_req_idx: torch.Tensor, + b_shared_seq_len: torch.Tensor, + b_mark_shared_group: torch.Tensor, + max_len_in_batch: int, + mid_out: torch.Tensor, 
+ mid_out_logsumexp: torch.Tensor, + block_seq: int, + max_batch_group_size: int, + run_config: Optional[dict] = None, +): + """ + 该kernel是为多样性生成定制的gqa算子,其中 b_mark_shared_group 是一个shape 为 (batch_size,)的tensor, + 其内容标记那些请求是共享前缀的请求组。举列说明: + b_shared_seq_len : [10, 10, 10, 11, 11, 11, 11] + b_mark_shared_group: [0, 0, 3, 0, 0, 0, 4] + b_mark_shared_group 中每一个不为0的位置都代表其与前面多少个请求形成一个共享前缀组。属于 + 同一个共享前缀组的请求, 其在对应的 b_shared_seq_len 中的内容必然相同。 + """ + if not run_config: + avg_seq_len_in_batch = max_len_in_batch + + run_config = GQADiverseDecodeStage1KernelConfig.try_to_get_best_config( + batch_size=int(q.shape[0]), + avg_seq_len_in_batch=avg_seq_len_in_batch, + gqa_group_size=int(q.shape[1] // k.shape[1]), + max_batch_group_size=max_batch_group_size, + q_head_dim=int(q.shape[2]), + block_seq=block_seq, + out_dtype=q.dtype, + ) + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 + BLOCK_SEQ = block_seq + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + assert Lk in {16, 32, 64, 128} + sm_scale = 1.0 / (Lk ** 0.5) + batch, kv_head_num = B_req_idx.shape[0], k.shape[1] + grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + assert triton.next_power_of_2(Lk) == Lk + KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] + assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE + BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) + BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size) + if BLOCK_HEAD * BLOCK_BATCH < 16: + BLOCK_BATCH = 16 // BLOCK_HEAD + assert k.stride() == v.stride() + NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE + assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS + + assert k.stride() == v.stride() + _fwd_kernel_flash_decode_diverse_stage1[grid]( + Q=q, + stride_qbs=q.stride(0), + stride_qh=q.stride(1), + stride_qd=q.stride(2), + K=k, + K_scale=k_scale, + stride_kbs=k.stride(0), + stride_kh=k.stride(1), + stride_kd=k.stride(2), + V=v, + V_scale=v_scale, + stride_vbs=v.stride(0), + stride_vh=v.stride(1), + stride_vd=v.stride(2), + sm_scale=sm_scale, + Req_to_tokens=Req_to_tokens, + stride_req_to_tokens_b=Req_to_tokens.stride(0), + stride_req_to_tokens_s=Req_to_tokens.stride(1), + B_req_idx=B_req_idx, + b_shared_seq_len=b_shared_seq_len, + b_mark_shared_group=b_mark_shared_group, + Mid_O=mid_out, + stride_mid_ob=mid_out.stride(0), + stride_mid_oh=mid_out.stride(1), + stride_mid_os=mid_out.stride(2), + stride_mid_od=mid_out.stride(3), + Mid_O_LogExpSum=mid_out_logsumexp, # [batch, head, seq_block_num] + stride_mid_o_eb=mid_out_logsumexp.stride(0), + stride_mid_o_eh=mid_out_logsumexp.stride(1), + stride_mid_o_es=mid_out_logsumexp.stride(2), + gqa_group_size=gqa_group_size, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=block_seq, + BLOCK_HEADDIM=Lk, + BLOCK_N=BLOCK_N, + BLOCK_BATCH=BLOCK_BATCH, + KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, + NUM_GROUPS=NUM_GROUPS, + num_warps=num_warps, + num_stages=num_stages, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py new file mode 100644 index 0000000000..a82af03349 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py @@ -0,0 +1,95 @@ +import torch 
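# A minimal reference for the merge that the stage-3 kernel below performs, restated
# from the code rather than taken verbatim from this patch: each (batch, head) pair
# combines its per-block partial outputs Mid_O with the running log-sum-exp values
# Mid_O_LogExpSum via out = sum_i exp(lse_i - m) * o_i / sum_i exp(lse_i - m), with
# m = max_i lse_i. In plain torch:
#
#     import torch
#
#     def merge_partial_blocks(mid_o: torch.Tensor, mid_lse: torch.Tensor) -> torch.Tensor:
#         # mid_o: [block_num, head_dim] partial outputs, mid_lse: [block_num] log-sum-exp
#         m = mid_lse.max()
#         w = torch.exp(mid_lse - m)  # un-normalized per-block weights
#         return (w[:, None] * mid_o).sum(dim=0) / w.sum()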
+import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_diverse_decode_stage3( + B_Seqlen, + b_shared_seq_len, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + O, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) + + shared_block_n = tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) + not_shared_block_n = tl.cdiv(cur_batch_seq_len - cur_batch_shared_len, BLOCK_SEQ) + + block_n_size = shared_block_n + not_shared_block_n + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_seq_n in range(0, block_n_size, 1): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp) + return + + +@torch.no_grad() +def flash_diverse_decode_stage3( + mid_out: torch.Tensor, + mid_out_logexpsum: torch.Tensor, + B_Seqlen: torch.Tensor, + b_shared_seq_len: torch.Tensor, + O: torch.Tensor, + block_seq: int, +): + Lk = mid_out.shape[-1] + assert Lk in {16, 32, 64, 128} + batch, head_num = mid_out.shape[0], mid_out.shape[1] + grid = (batch, head_num) + + _fwd_kernel_flash_diverse_decode_stage3[grid]( + B_Seqlen=B_Seqlen, + b_shared_seq_len=b_shared_seq_len, + Mid_O=mid_out, + Mid_O_LogExpSum=mid_out_logexpsum, + O=O, + stride_mid_ob=mid_out.stride(0), + stride_mid_oh=mid_out.stride(1), + stride_mid_os=mid_out.stride(2), + stride_mid_od=mid_out.stride(3), + stride_mid_o_eb=mid_out_logexpsum.stride(0), + stride_mid_o_eh=mid_out_logexpsum.stride(1), + stride_mid_o_es=mid_out_logexpsum.stride(2), + stride_obs=O.stride(0), + stride_oh=O.stride(1), + stride_od=O.stride(2), + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=Lk, + num_warps=4, + num_stages=2, + ) + return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt new file mode 100644 index 0000000000..f9a89537d5 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt @@ -0,0 +1 @@ +1. 
设计一个支持 vsm的 decoding 算子, grid 分配方式 为 (batch_size, vsm_count) \ No newline at end of file From 9ab920ae4ac7efbba20ed1de486ab4d8b3aace09 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Thu, 19 Mar 2026 06:07:20 +0000 Subject: [PATCH 2/4] fix --- .../basemodel/attention/triton/int8kv.py | 10 +- .../att/decode_att/int8kv/normal/__init__.py | 1 + .../normal/int8kv_flash_decodin_stage2.py | 306 ----------------- .../int8kv/normal/int8kv_flash_decoding.py | 104 ++---- .../normal/int8kv_flash_decoding_stage1.py | 312 ++++++++---------- ...ge3.py => int8kv_flash_decoding_stage2.py} | 24 +- .../att/decode_att/int8kv/normal/readme.txt | 1 - 7 files changed, 183 insertions(+), 575 deletions(-) delete mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py rename lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/{int8kv_flash_decoding_stage3.py => int8kv_flash_decoding_stage2.py} (76%) delete mode 100644 lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt diff --git a/lightllm/common/basemodel/attention/triton/int8kv.py b/lightllm/common/basemodel/attention/triton/int8kv.py index 975d7b629c..a47f63a40e 100644 --- a/lightllm/common/basemodel/attention/triton/int8kv.py +++ b/lightllm/common/basemodel/attention/triton/int8kv.py @@ -139,7 +139,7 @@ def decode_att( if enable_diverse_mode_gqa_decode_fast_kernel(): return self.diverse_decode_att(q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, alloc_func=alloc_func) else: - return self.ppl_mha_int8kv_decode_att( + return self.normal_decode_att( q=q, k=k, k_scale=k_scale, @@ -172,7 +172,7 @@ def diverse_decode_att( alloc_tensor_func=alloc_func, ) - def ppl_mha_int8kv_decode_att( + def normal_decode_att( self, q: torch.Tensor, k: torch.Tensor, @@ -180,10 +180,8 @@ def ppl_mha_int8kv_decode_att( v: torch.Tensor, v_scale: torch.Tensor, alloc_func=torch.empty, - ) -> torch.Tensor: - from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import ( - token_decode_attention_flash_decoding, - ) + ): + from ...triton_kernel.att.decode_att.int8kv.normal import token_decode_attention_flash_decoding return token_decode_attention_flash_decoding( q=q, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py index e69de29bb2..8cc4aa919b 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/__init__.py @@ -0,0 +1 @@ +from .int8kv_flash_decoding import token_decode_attention_flash_decoding diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py deleted file mode 100644 index f5c0b9c395..0000000000 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decodin_stage2.py +++ /dev/null @@ -1,306 +0,0 @@ -import torch -import triton -import triton.language as tl -from typing import Optional - -from lightllm.common.kernel_config import KernelConfigs -from frozendict import frozendict -from functools import lru_cache -from typing import Dict -from lightllm.common.triton_utils.autotuner import autotune, Autotuner - - -class GQADiverseDecodeStage2KernelConfig(KernelConfigs): - kernel_name: str = "_fwd_kernel_flash_decode_diverse_stage2:v1" - - 
@classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - batch_size: int, - avg_seq_len_in_batch: int, - gqa_group_size: int, - q_head_dim: int, - block_seq: int, - out_dtype: str, - ) -> dict: - key_params = { - "gqa_group_size": gqa_group_size, - "q_head_dim": q_head_dim, - "block_seq": block_seq, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - batch_size_config: dict = finded_config[ - min( - finded_config.keys(), - key=lambda x: abs(int(x) - avg_seq_len_in_batch), - ) - ] - config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] - - return config - else: - config = { - "BLOCK_N": 16, - "num_warps": 2, - "num_stages": 2, - } - return config - - @classmethod - def save_config( - cls, - gqa_group_size: int, - q_head_dim: int, - block_seq: int, - out_dtype: str, - config_json: Dict[int, Dict[int, Dict]], - ): - key_params = { - "gqa_group_size": gqa_group_size, - "q_head_dim": q_head_dim, - "block_seq": block_seq, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) - - -@triton.jit -def _fwd_kernel_flash_decode_diverse_stage2( - Q, - stride_qbs, - stride_qh, - stride_qd, - K, - K_scale, - stride_kbs, - stride_kh, - stride_kd, - V, - V_scale, - stride_vbs, - stride_vh, - stride_vd, - sm_scale, - Req_to_tokens, - stride_req_to_tokens_b, - stride_req_to_tokens_s, - B_req_idx, - B_Seqlen, - b_shared_seq_len, - Mid_O, # [batch, head, seq_block_num, head_dim] - stride_mid_ob, - stride_mid_oh, - stride_mid_os, - stride_mid_od, - Mid_O_LogExpSum, # [batch, head, seq_block_num] - stride_mid_o_eb, - stride_mid_o_eh, - stride_mid_o_es, - gqa_group_size: tl.constexpr, - BLOCK_SEQ: tl.constexpr, - BLOCK_HEADDIM: tl.constexpr, - BLOCK_N: tl.constexpr, - KV_QUANT_GROUP_SIZE: tl.constexpr, - NUM_GROUPS: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_kv_head = tl.program_id(1) - seq_start_block = tl.program_id(2) - - cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, gqa_group_size) - - offs_d = tl.arange(0, BLOCK_HEADDIM) - offs_d_scale = tl.arange(0, NUM_GROUPS) - cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_shared_len - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - store_seq_block = seq_start_block + tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) - - off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] - - block_n_size = tl.cdiv( - tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), - BLOCK_N, - ) - - if block_n_size == 0: - return - - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - q = tl.load(Q + off_q) - - sum_exp = tl.zeros([gqa_group_size], dtype=tl.float32) - max_logic = tl.zeros([gqa_group_size], dtype=tl.float32) - float("inf") - acc = tl.zeros([gqa_group_size, BLOCK_HEADDIM], dtype=tl.float32) - - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - n_mask = offs_n_new < cur_batch_end_index - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=n_mask, - other=0, - ).to(tl.int64) - off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh - # 
(128, 16) - off_k = off_k_base[None, :] + offs_d[:, None] - # off_k_scale = off_k // KV_QUANT_GROUP_SIZE - # (16, 16) - off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] - k = tl.load(K + off_k, mask=n_mask[None, :], other=0) - k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) - k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) - k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) - k = k * k_scale - k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) - # q (4, 128) k (128, BLOCK_N) - att_value = tl.dot(q, k.to(q.dtype)) - att_value *= sm_scale - att_value = tl.where(n_mask[None, :], att_value, float("-inf")) - off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - v = tl.load( - V + off_v, - mask=n_mask[:, None], - other=0, - ) - v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) - v_scale = tl.load( - V_scale + off_k_scale, - mask=n_mask[None, :], - other=0.0, - ) - v_scale = tl.trans(v_scale) - v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) - v = v * v_scale - v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) - - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) - - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) - - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) - max_logic = new_max_logic - - off_mid_o = ( - cur_batch * stride_mid_ob - + cur_q_head_range[:, None] * stride_mid_oh - + store_seq_block * stride_mid_os - + offs_d[None, :] - ) - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + store_seq_block - tl.store( - Mid_O + off_mid_o, - (acc / sum_exp[:, None]), - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - (max_logic + tl.log(sum_exp)), - ) - return - - -@torch.no_grad() -def flash_decode_stage2( - q: torch.Tensor, - k: torch.Tensor, - k_scale: torch.Tensor, - v: torch.Tensor, - v_scale: torch.Tensor, - Req_to_tokens: torch.Tensor, - B_req_idx: torch.Tensor, - B_Seqlen: torch.Tensor, - b_shared_seq_len: torch.Tensor, - max_len_in_batch: int, - mid_out: torch.Tensor, - mid_out_logsumexp: torch.Tensor, - block_seq: int, - run_config: Optional[dict] = None, -): - if not run_config: - run_config = GQADiverseDecodeStage2KernelConfig.try_to_get_best_config( - batch_size=int(q.shape[0]), - avg_seq_len_in_batch=max_len_in_batch, - gqa_group_size=int(q.shape[1] // k.shape[1]), - q_head_dim=int(q.shape[2]), - block_seq=block_seq, - out_dtype=q.dtype, - ) - - BLOCK_N = run_config["BLOCK_N"] - num_warps = run_config["num_warps"] - num_stages = run_config["num_stages"] - - assert q.dim() == 3 and k.dim() == 3 and v.dim() == 3 - BLOCK_SEQ = block_seq - assert BLOCK_SEQ % BLOCK_N == 0 - # shape constraints - Lq, Lk = q.shape[-1], k.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 128} - sm_scale = 1.0 / (Lk ** 0.5) - batch, kv_head_num = B_req_idx.shape[0], k.shape[1] - grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) - gqa_group_size = q.shape[1] // k.shape[1] - assert triton.next_power_of_2(Lk) == Lk - KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] - assert KV_QUANT_GROUP_SIZE == 8 - NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE - assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS - - assert k.stride() == v.stride() - - _fwd_kernel_flash_decode_diverse_stage2[grid]( - Q=q, - stride_qbs=q.stride(0), - 
stride_qh=q.stride(1), - stride_qd=q.stride(2), - K=k, - K_scale=k_scale, - stride_kbs=k.stride(0), - stride_kh=k.stride(1), - stride_kd=k.stride(2), - V=v, - V_scale=v_scale, - stride_vbs=v.stride(0), - stride_vh=v.stride(1), - stride_vd=v.stride(2), - sm_scale=sm_scale, - Req_to_tokens=Req_to_tokens, - stride_req_to_tokens_b=Req_to_tokens.stride(0), - stride_req_to_tokens_s=Req_to_tokens.stride(1), - B_req_idx=B_req_idx, - B_Seqlen=B_Seqlen, - b_shared_seq_len=b_shared_seq_len, - Mid_O=mid_out, - stride_mid_ob=mid_out.stride(0), - stride_mid_oh=mid_out.stride(1), - stride_mid_os=mid_out.stride(2), - stride_mid_od=mid_out.stride(3), - Mid_O_LogExpSum=mid_out_logsumexp, # [batch, head, seq_block_num] - stride_mid_o_eb=mid_out_logsumexp.stride(0), - stride_mid_o_eh=mid_out_logsumexp.stride(1), - stride_mid_o_es=mid_out_logsumexp.stride(2), - gqa_group_size=gqa_group_size, - BLOCK_SEQ=block_seq, - BLOCK_HEADDIM=Lk, - BLOCK_N=BLOCK_N, - KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, - NUM_GROUPS=NUM_GROUPS, - num_warps=num_warps, - num_stages=num_stages, - ) - return diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py index ad6a8b5b3a..b61e8eace2 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding.py @@ -1,95 +1,63 @@ -# 为 diverse mode 定制设计的 int8kv flash decoding attention 实现,可以实现更高效的多样性采样 import torch -from lightllm.utils.light_utils import HAS_LIGHTLLM_KERNEL, light_ops +from typing import Optional from lightllm.common.basemodel.infer_struct import InferStateInfo -from .int8kv_flash_decoding_diverse_stage1 import flash_decode_stage1 -from .int8kv_flash_decoding_diverse_stage2 import flash_decode_stage2 -from .int8kv_flash_decoding_diverse_stage3 import flash_diverse_decode_stage3 -from lightllm.utils.envs_utils import get_diverse_max_batch_shared_group_size +from .int8kv_flash_decoding_stage1 import flash_decode_stage1 +from .int8kv_flash_decoding_stage2 import flash_decode_stage2 +@torch.no_grad() def token_decode_attention_flash_decoding( - q, + q: torch.Tensor, infer_state: InferStateInfo, - cache_k, - cache_k_scale, - cache_v, - cache_v_scale, - out=None, + cache_k: torch.Tensor, + cache_k_scale: torch.Tensor, + cache_v: torch.Tensor, + cache_v_scale: torch.Tensor, + out: Optional[torch.Tensor] = None, alloc_tensor_func=torch.empty, - shared_streams_dict={}, ): - if "stream1" not in shared_streams_dict: - shared_streams_dict["stream1"] = torch.cuda.Stream() - if "stream2" not in shared_streams_dict: - shared_streams_dict["stream2"] = torch.cuda.Stream() - - stream1 = shared_streams_dict["stream1"] - stream2 = shared_streams_dict["stream2"] q_head_num = q.shape[1] head_dim = q.shape[2] BLOCK_SEQ = 256 batch_size = infer_state.batch_size - max_kv_seq_len = infer_state.max_kv_seq_len calcu_shape1 = (batch_size, q_head_num, head_dim) o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out - mid_o = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2, head_dim], dtype=torch.float32, device="cuda" - ) - mid_o_logexpsum = alloc_tensor_func( - [batch_size, q_head_num, max_kv_seq_len // BLOCK_SEQ + 2], dtype=torch.float32, device="cuda" - ) - - current_stream = torch.cuda.current_stream() + # 因为需要分配一些中间tensor,考虑到并行度和中间显存的消耗,batch_size 小的 + # 时候 block_num 
较大, batch_size 大的时候 block_num 较小。这样可以达到较好 + # 的显存消耗和性能的平衡。 + if batch_size <= 16: + block_num = 128 + elif batch_size <= 64: + block_num = 64 + else: + block_num = 32 - stream1.wait_stream(current_stream) - with torch.cuda.stream(stream1): - flash_decode_stage1( - q=q.view(calcu_shape1), - k=cache_k, - k_scale=cache_k_scale, - v=cache_v, - v_scale=cache_v_scale, - Req_to_tokens=infer_state.req_manager.req_to_token_indexs, - B_req_idx=infer_state.b_req_idx, - b_shared_seq_len=infer_state.b_shared_seq_len, - b_mark_shared_group=infer_state.b_mark_shared_group, - max_len_in_batch=infer_state.max_kv_seq_len, - mid_out=mid_o, - mid_out_logsumexp=mid_o_logexpsum, - block_seq=BLOCK_SEQ, - max_batch_group_size=get_diverse_max_batch_shared_group_size(), - ) - stream2.wait_stream(current_stream) - with torch.cuda.stream(stream2): - flash_decode_stage2( - q=q.view(calcu_shape1), - k=cache_k, - k_scale=cache_k_scale, - v=cache_v, - v_scale=cache_v_scale, - Req_to_tokens=infer_state.req_manager.req_to_token_indexs, - B_req_idx=infer_state.b_req_idx, - B_Seqlen=infer_state.b_seq_len, - b_shared_seq_len=infer_state.b_shared_seq_len, - max_len_in_batch=infer_state.max_kv_seq_len, - mid_out=mid_o, - mid_out_logsumexp=mid_o_logexpsum, - block_seq=BLOCK_SEQ, - ) + mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device="cuda") + mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device="cuda") - current_stream.wait_stream(stream1) - current_stream.wait_stream(stream2) + flash_decode_stage1( + q=q.view(calcu_shape1), + k=cache_k, + k_scale=cache_k_scale, + v=cache_v, + v_scale=cache_v_scale, + Req_to_tokens=infer_state.req_manager.req_to_token_indexs, + B_req_idx=infer_state.b_req_idx, + B_seq_len=infer_state.b_seq_len, + max_len_in_batch=infer_state.max_kv_seq_len, + mid_out=mid_o, + mid_out_logsumexp=mid_o_logexpsum, + block_seq=BLOCK_SEQ, + ) - flash_diverse_decode_stage3( + flash_decode_stage2( mid_out=mid_o, mid_out_logexpsum=mid_o_logexpsum, B_Seqlen=infer_state.b_seq_len, - b_shared_seq_len=infer_state.b_shared_seq_len, O=o_tensor.view(calcu_shape1), block_seq=BLOCK_SEQ, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py index 8040f657c6..7ee929b5f9 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py @@ -6,72 +6,20 @@ from frozendict import frozendict from functools import lru_cache from typing import Dict +from lightllm.common.triton_utils.autotuner import autotune +# key_params = { +# "gqa_group_size": gqa_group_size, +# "q_head_dim": q_head_dim, +# "block_seq": block_seq, +# "out_dtype": str(out_dtype), +# } -class GQANormalDecodeStage1KernelConfig(KernelConfigs): - kernel_name: str = "_fwd_kernel_flash_decode_normal_stage1:v2" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - batch_size: int, - avg_seq_len_in_batch: int, - gqa_group_size: int, - max_batch_group_size: int, - q_head_dim: int, - block_seq: int, - out_dtype: str, - ) -> dict: - key_params = { - "gqa_group_size": gqa_group_size, - "max_batch_group_size": max_batch_group_size, - "q_head_dim": q_head_dim, - "block_seq": block_seq, - "out_dtype": str(out_dtype), - } - key_params = 
frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - batch_size_config: dict = finded_config[ - min( - finded_config.keys(), - key=lambda x: abs(int(x) - avg_seq_len_in_batch), - ) - ] - config = batch_size_config[min(batch_size_config.keys(), key=lambda x: abs(int(x) - batch_size))] - - return config - else: - config = { - "BLOCK_N": 16, - "num_warps": 2, - "num_stages": 2, - } - return config - - @classmethod - def save_config( - cls, - gqa_group_size: int, - max_batch_group_size: int, - q_head_dim: int, - block_seq: int, - out_dtype: str, - config_json: Dict[int, Dict[int, Dict]], - ): - key_params = { - "gqa_group_size": gqa_group_size, - "max_batch_group_size": max_batch_group_size, - "q_head_dim": q_head_dim, - "block_seq": block_seq, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) +# config = { +# "BLOCK_N": 16, +# "num_warps": 2, +# "num_stages": 2, +# } @triton.jit @@ -95,8 +43,7 @@ def _fwd_kernel_flash_decode_normal_stage1( stride_req_to_tokens_b, stride_req_to_tokens_s, B_req_idx, - b_shared_seq_len, - b_mark_shared_group, + b_seq_len, Mid_O, # [batch, head, seq_block_num, head_dim] stride_mid_ob, stride_mid_oh, @@ -111,18 +58,17 @@ def _fwd_kernel_flash_decode_normal_stage1( BLOCK_SEQ: tl.constexpr, BLOCK_HEADDIM: tl.constexpr, BLOCK_N: tl.constexpr, - BLOCK_BATCH: tl.constexpr, KV_QUANT_GROUP_SIZE: tl.constexpr, NUM_GROUPS: tl.constexpr, ): cur_batch = tl.program_id(0) - shared_batch_group_size = tl.load(b_mark_shared_group + cur_batch) - if shared_batch_group_size == 0: - return - cur_batch_end = cur_batch + 1 - cur_batch = cur_batch - (shared_batch_group_size - 1) cur_kv_head = tl.program_id(1) - seq_start_block = tl.program_id(2) + block_index = tl.program_id(2) + grid_block_num = tl.num_programs(2) + cur_batch_seq_len = tl.load(b_seq_len + cur_batch) + req_total_block_num = tl.cdiv(cur_batch_seq_len, BLOCK_SEQ) + if block_index >= req_total_block_num: + return cur_q_head_range = cur_kv_head * gqa_group_size + tl.arange(0, BLOCK_HEAD) q_head_end_index = (cur_kv_head + 1) * gqa_group_size @@ -130,104 +76,127 @@ def _fwd_kernel_flash_decode_normal_stage1( offs_d = tl.arange(0, BLOCK_HEADDIM) offs_d_scale = tl.arange(0, NUM_GROUPS) - cur_batch_seq_len = tl.load(b_shared_seq_len + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ - cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - - offs_batch = cur_batch + tl.arange(0, BLOCK_BATCH) - offs_batch = tl.where(offs_batch < cur_batch_end, offs_batch, cur_batch) - - off_q = offs_batch[:, None, None] * stride_qbs + cur_q_head_range[None, :, None] * stride_qh + offs_d[None, None, :] - - block_n_size = tl.cdiv( - tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index), - BLOCK_N, - ) + off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + off_q) # (BLOCK_HEAD, BLOCK_HEADDIM) - if block_n_size == 0: - return + sum_exp = tl.zeros([BLOCK_HEAD], dtype=tl.float32) + max_logic = tl.zeros([BLOCK_HEAD], dtype=tl.float32) - float("inf") + acc = tl.zeros([BLOCK_HEAD, BLOCK_HEADDIM], dtype=tl.float32) - offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) - Q_BATCH_HEAD_NUM: tl.constexpr = BLOCK_BATCH * BLOCK_HEAD - q = tl.load(Q + off_q).reshape(Q_BATCH_HEAD_NUM, BLOCK_HEADDIM) + for iter_block_index in 
range(block_index, req_total_block_num, grid_block_num): + cur_batch_start_index = iter_block_index * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) - sum_exp = tl.zeros([Q_BATCH_HEAD_NUM], dtype=tl.float32) - max_logic = tl.zeros([Q_BATCH_HEAD_NUM], dtype=tl.float32) - float("inf") - acc = tl.zeros([Q_BATCH_HEAD_NUM, BLOCK_HEADDIM], dtype=tl.float32) + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + block_n_size = tl.cdiv(cur_batch_end_index - cur_batch_start_index, BLOCK_N) - for start_n in range(0, block_n_size, 1): - offs_n_new = start_n * BLOCK_N + offs_n - n_mask = offs_n_new < cur_batch_end_index - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=n_mask, - other=0, - ).to(tl.int64) - off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh - # (128, 16) - off_k = off_k_base[None, :] + offs_d[:, None] - # off_k_scale = off_k // KV_QUANT_GROUP_SIZE - # (16, 16) - off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] - k = tl.load(K + off_k, mask=n_mask[None, :], other=0) - k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) - k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) - k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) - k = k * k_scale - k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) - att_value = tl.dot(q, k.to(q.dtype)) - att_value *= sm_scale - att_value = tl.where(n_mask[None, :], att_value, float("-inf")) - off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] - v = tl.load( - V + off_v, - mask=n_mask[:, None], - other=0, - ) - v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) - v_scale = tl.load( - V_scale + off_k_scale, - mask=n_mask[None, :], - other=0.0, - ) - v_scale = tl.trans(v_scale) - v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) - v = v * v_scale - v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + n_mask = offs_n_new < cur_batch_end_index + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=n_mask, + other=0, + ).to(tl.int64) + off_k_base = k_loc * stride_kbs + cur_kv_head * stride_kh + off_k = off_k_base[None, :] + offs_d[:, None] + # off_k_base // KV_QUANT_GROUP_SIZE 是一种取巧计算stride的方式 + off_k_scale = off_k_base[None, :] // KV_QUANT_GROUP_SIZE + offs_d_scale[:, None] + k = tl.load(K + off_k, mask=n_mask[None, :], other=0) + k = tl.reshape(k, (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N)) + k_scale = tl.load(K_scale + off_k_scale, mask=n_mask[None, :], other=0.0) + k_scale = tl.reshape(k_scale, (NUM_GROUPS, 1, BLOCK_N)) + k = k * k_scale + k = tl.reshape(k, (BLOCK_HEADDIM, BLOCK_N)) + att_value = tl.dot(q, k.to(q.dtype)) + att_value *= sm_scale + att_value = tl.where(n_mask[None, :], att_value, float("-inf")) + off_v = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + v = tl.load( + V + off_v, + mask=n_mask[:, None], + other=0, + ) + v = tl.reshape(v, (BLOCK_N, NUM_GROUPS, KV_QUANT_GROUP_SIZE)) + v_scale = tl.load( + V_scale + off_k_scale, + mask=n_mask[None, :], + other=0.0, + ) + v_scale = tl.trans(v_scale) + v_scale = tl.reshape(v_scale, (BLOCK_N, NUM_GROUPS, 1)) + v = v * v_scale + v = tl.reshape(v, (BLOCK_N, BLOCK_HEADDIM)) - cur_max_logic = tl.max(att_value, axis=1) - new_max_logic = tl.maximum(cur_max_logic, max_logic) + cur_max_logic = tl.max(att_value, axis=1) + new_max_logic = 
tl.maximum(cur_max_logic, max_logic) - exp_logic = tl.exp(att_value - new_max_logic[:, None]) - logic_scale = tl.exp(max_logic - new_max_logic) - acc *= logic_scale[:, None] - acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) + exp_logic = tl.exp(att_value - new_max_logic[:, None]) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale[:, None] + acc += tl.dot(exp_logic.to(q.dtype), v.to(q.dtype)) - sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) - max_logic = new_max_logic + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1) + max_logic = new_max_logic off_mid_o = ( - offs_batch[:, None, None] * stride_mid_ob - + cur_q_head_range[None, :, None] * stride_mid_oh - + seq_start_block * stride_mid_os - + offs_d[None, None, :] - ) - off_mid_o_logexpsum = ( - offs_batch[:, None] * stride_mid_o_eb + cur_q_head_range[None, :] * stride_mid_o_eh + seq_start_block + cur_batch * stride_mid_ob + + cur_q_head_range[:, None] * stride_mid_oh + + block_index * stride_mid_os + + offs_d[None, :] ) + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + block_index tl.store( Mid_O + off_mid_o, - (acc / sum_exp[:, None]).reshape(BLOCK_BATCH, BLOCK_HEAD, BLOCK_HEADDIM), + (acc / sum_exp[:, None]).reshape(BLOCK_HEAD, BLOCK_HEADDIM), ) tl.store( Mid_O_LogExpSum + off_mid_o_logexpsum, - (max_logic + tl.log(sum_exp)).reshape(BLOCK_BATCH, BLOCK_HEAD), + (max_logic + tl.log(sum_exp)), ) return -@torch.no_grad() +def get_test_configs(): + configs = [] + for block_n in [16, 32, 64, 128]: + for num_warps in [2, 4, 8, 16]: + for num_stages in [2, 4, 6]: + configs.append( + { + "BLOCK_N": block_n, + "num_warps": num_warps, + "num_stages": num_stages, + } + ) + return configs + + +def get_static_key(q, k, block_seq): + key_params = { + "gqa_group_size": int(q.shape[1] // k.shape[1]), + "q_head_dim": int(q.shape[2]), + "block_seq": block_seq, + "out_dtype": str(q.dtype), + } + return key_params + + +def get_run_key(q, max_len_in_batch): + batch_size = q.shape[0] + return batch_size * 1024 * 1024 * 1024 + max_len_in_batch + + +autotune( + kernel_name="_fwd_kernel_flash_decode_normal_stage1:v2", + configs_gen_func=get_test_configs, + static_key_func=get_static_key, + run_key_func=get_run_key, + mutates_args=["mid_out", "mid_out_logsumexp"], +) + + def flash_decode_stage1( q: torch.Tensor, k: torch.Tensor, @@ -236,35 +205,21 @@ def flash_decode_stage1( v_scale: torch.Tensor, Req_to_tokens: torch.Tensor, B_req_idx: torch.Tensor, - b_shared_seq_len: torch.Tensor, - b_mark_shared_group: torch.Tensor, + B_seq_len: torch.Tensor, max_len_in_batch: int, mid_out: torch.Tensor, mid_out_logsumexp: torch.Tensor, block_seq: int, - max_batch_group_size: int, run_config: Optional[dict] = None, ): - """ - 该kernel是为多样性生成定制的gqa算子,其中 b_mark_shared_group 是一个shape 为 (batch_size,)的tensor, - 其内容标记那些请求是共享前缀的请求组。举列说明: - b_shared_seq_len : [10, 10, 10, 11, 11, 11, 11] - b_mark_shared_group: [0, 0, 3, 0, 0, 0, 4] - b_mark_shared_group 中每一个不为0的位置都代表其与前面多少个请求形成一个共享前缀组。属于 - 同一个共享前缀组的请求, 其在对应的 b_shared_seq_len 中的内容必然相同。 - """ + """ """ if not run_config: - avg_seq_len_in_batch = max_len_in_batch - run_config = GQADiverseDecodeStage1KernelConfig.try_to_get_best_config( - batch_size=int(q.shape[0]), - avg_seq_len_in_batch=avg_seq_len_in_batch, - gqa_group_size=int(q.shape[1] // k.shape[1]), - max_batch_group_size=max_batch_group_size, - q_head_dim=int(q.shape[2]), - block_seq=block_seq, - out_dtype=q.dtype, - ) + run_config = { + "BLOCK_N": 16, + "num_warps": 2, + "num_stages": 2, + } 
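# Hypothetical sketch (no such helper exists in this patch) of how a cache writer could
# produce the int8 + per-group-scale layout this wrapper expects, where groups of
# KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] channels share one float scale:
#
#     import torch
#
#     def quantize_kv_per_group(x: torch.Tensor, group_size: int = 8):
#         # x: [num_tokens, kv_head_num, head_dim], float16 / bfloat16
#         t, h, d = x.shape
#         g = x.view(t, h, d // group_size, group_size)
#         scale = g.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 127.0
#         q = (g / scale).round().clamp(-127, 127).to(torch.int8)
#         return q.view(t, h, d), scale.view(t, h, d // group_size)
#
# Dequantization inside the kernel is then k_int8 * k_scale per group, which is what the
# reshape to (NUM_GROUPS, KV_QUANT_GROUP_SIZE, BLOCK_N) implements.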
BLOCK_N = run_config["BLOCK_N"] num_warps = run_config["num_warps"] @@ -279,21 +234,20 @@ def flash_decode_stage1( assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] - grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) gqa_group_size = q.shape[1] // k.shape[1] assert triton.next_power_of_2(Lk) == Lk KV_QUANT_GROUP_SIZE = v.shape[-1] // v_scale.shape[-1] assert triton.next_power_of_2(KV_QUANT_GROUP_SIZE) == KV_QUANT_GROUP_SIZE BLOCK_HEAD = triton.next_power_of_2(gqa_group_size) - BLOCK_BATCH = triton.next_power_of_2(max_batch_group_size) - if BLOCK_HEAD * BLOCK_BATCH < 16: - BLOCK_BATCH = 16 // BLOCK_HEAD + assert k.stride() == v.stride() NUM_GROUPS = Lk // KV_QUANT_GROUP_SIZE assert triton.next_power_of_2(NUM_GROUPS) == NUM_GROUPS assert k.stride() == v.stride() - _fwd_kernel_flash_decode_diverse_stage1[grid]( + block_num = mid_out.shape[2] + grid = (batch, kv_head_num, block_num) + _fwd_kernel_flash_decode_normal_stage1[grid]( Q=q, stride_qbs=q.stride(0), stride_qh=q.stride(1), @@ -313,8 +267,7 @@ def flash_decode_stage1( stride_req_to_tokens_b=Req_to_tokens.stride(0), stride_req_to_tokens_s=Req_to_tokens.stride(1), B_req_idx=B_req_idx, - b_shared_seq_len=b_shared_seq_len, - b_mark_shared_group=b_mark_shared_group, + b_seq_len=B_seq_len, Mid_O=mid_out, stride_mid_ob=mid_out.stride(0), stride_mid_oh=mid_out.stride(1), @@ -329,7 +282,6 @@ def flash_decode_stage1( BLOCK_SEQ=block_seq, BLOCK_HEADDIM=Lk, BLOCK_N=BLOCK_N, - BLOCK_BATCH=BLOCK_BATCH, KV_QUANT_GROUP_SIZE=KV_QUANT_GROUP_SIZE, NUM_GROUPS=NUM_GROUPS, num_warps=num_warps, diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py similarity index 76% rename from lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py rename to lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py index a82af03349..43dc6051e2 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage3.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage2.py @@ -4,9 +4,8 @@ @triton.jit -def _fwd_kernel_flash_diverse_decode_stage3( +def _fwd_kernel_flash_normal_decode_stage2( B_Seqlen, - b_shared_seq_len, Mid_O, # [batch, head, seq_block_num, head_dim] Mid_O_LogExpSum, # [batch, head, seq_block_num] O, # [batch, head, head_dim] @@ -20,6 +19,7 @@ def _fwd_kernel_flash_diverse_decode_stage3( stride_obs, stride_oh, stride_od, + block_num, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, ): @@ -28,12 +28,8 @@ def _fwd_kernel_flash_diverse_decode_stage3( offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_shared_len = tl.load(b_shared_seq_len + cur_batch) - shared_block_n = tl.cdiv(cur_batch_shared_len, BLOCK_SEQ) - not_shared_block_n = tl.cdiv(cur_batch_seq_len - cur_batch_shared_len, BLOCK_SEQ) - - block_n_size = shared_block_n + not_shared_block_n + block_num = tl.minimum(tl.cdiv(cur_batch_seq_len, BLOCK_SEQ), block_num) sum_exp = 0.0 max_logic = -float("inf") @@ -41,9 +37,9 @@ def _fwd_kernel_flash_diverse_decode_stage3( offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh - for block_seq_n in range(0, 
-        tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)
-        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)
+    for block_index in range(0, block_num, 1):
+        tv = tl.load(Mid_O + offs_v + block_index * stride_mid_os)
+        tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_index)

         new_max_logic = tl.maximum(tlogic, max_logic)
         old_scale = tl.exp(max_logic - new_max_logic)
@@ -58,11 +54,10 @@

 @torch.no_grad()
-def flash_diverse_decode_stage3(
+def flash_decode_stage2(
     mid_out: torch.Tensor,
     mid_out_logexpsum: torch.Tensor,
     B_Seqlen: torch.Tensor,
-    b_shared_seq_len: torch.Tensor,
     O: torch.Tensor,
     block_seq: int,
 ):
@@ -70,10 +65,10 @@
     assert Lk in {16, 32, 64, 128}
     batch, head_num = mid_out.shape[0], mid_out.shape[1]
     grid = (batch, head_num)
+    block_num = mid_out.shape[2]

-    _fwd_kernel_flash_diverse_decode_stage3[grid](
+    _fwd_kernel_flash_normal_decode_stage2[grid](
         B_Seqlen=B_Seqlen,
-        b_shared_seq_len=b_shared_seq_len,
         Mid_O=mid_out,
         Mid_O_LogExpSum=mid_out_logexpsum,
         O=O,
@@ -87,6 +82,7 @@
         stride_obs=O.stride(0),
         stride_oh=O.stride(1),
         stride_od=O.stride(2),
+        block_num=block_num,
         BLOCK_SEQ=block_seq,
         BLOCK_DMODEL=Lk,
         num_warps=4,
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt
deleted file mode 100644
index f9a89537d5..0000000000
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/readme.txt
+++ /dev/null
@@ -1 +0,0 @@
-1. 设计一个支持 vsm的 decoding 算子, grid 分配方式 为 (batch_size, vsm_count)
\ No newline at end of file

From 9aa0f54add59b0321d0f4e36536ebc776a32e54a Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 19 Mar 2026 06:26:10 +0000
Subject: [PATCH 3/4] fix

---
 lightllm/common/basemodel/attention/create_utils.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/lightllm/common/basemodel/attention/create_utils.py b/lightllm/common/basemodel/attention/create_utils.py
index 1fcde2a5ca..63bff69a88 100644
--- a/lightllm/common/basemodel/attention/create_utils.py
+++ b/lightllm/common/basemodel/attention/create_utils.py
@@ -27,13 +27,13 @@
     },
     "int4kv": {
         "triton": Int4kvTritonAttBackend,
-        "fa3": Fp8Fa3AttBackend,
-        "flashinfer": Fp8FlashInferAttBackend,
+        # "fa3": Fp8Fa3AttBackend,
+        # "flashinfer": Fp8FlashInferAttBackend,
     },
     "int8kv": {
         "triton": Int8kvTritonAttBackend,
-        "fa3": Fp8Fa3AttBackend,
-        "flashinfer": Fp8FlashInferAttBackend,
+        # "fa3": Fp8Fa3AttBackend,
+        # "flashinfer": Fp8FlashInferAttBackend,
     },
 }

@@ -66,7 +66,7 @@ def _auto_select_backend(
     backend_map = kv_type_to_backend

     for backend_name in priority_list:
-        if validate(backend_name):
+        if backend_name in backend_map[llm_dtype] and validate(backend_name):
             logger.info(f"Auto-selected {backend_name} backend (validated)")
             return backend_map[llm_dtype][backend_name]


From 6a602bb63018c3ab53462594d3ca3cf2a37428fe Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 19 Mar 2026 07:27:58 +0000
Subject: [PATCH 4/4] fix

---
 .../normal/int8kv_flash_decoding_stage1.py    |  94 ++++++++++----
 ....bfloat16,q_head_dim=128}_NVIDIA_H200.json | 122 ++++++++++++++++++
 2 files changed, 192 insertions(+), 24 deletions(-)
 create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/_fwd_kernel_flash_decode_normal_stage1:v3/{block_seq=256,gqa_group_size=4,kv_quant_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json
diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py
index 7ee929b5f9..76327e93cb 100644
--- a/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py
+++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/int8kv/normal/int8kv_flash_decoding_stage1.py
@@ -2,24 +2,7 @@
 import triton
 import triton.language as tl
 from typing import Optional
-from lightllm.common.kernel_config import KernelConfigs
-from frozendict import frozendict
-from functools import lru_cache
-from typing import Dict
-from lightllm.common.triton_utils.autotuner import autotune
-
-# key_params = {
-#     "gqa_group_size": gqa_group_size,
-#     "q_head_dim": q_head_dim,
-#     "block_seq": block_seq,
-#     "out_dtype": str(out_dtype),
-# }
-
-# config = {
-#     "BLOCK_N": 16,
-#     "num_warps": 2,
-#     "num_stages": 2,
-# }
+from lightllm.common.triton_utils.autotuner import autotune, Autotuner


 @triton.jit
@@ -173,8 +156,9 @@ def get_test_configs():
     return configs


-def get_static_key(q, k, block_seq):
+def get_static_key(q, k, k_scale, block_seq):
     key_params = {
+        "kv_quant_group_size": k.shape[-1] // k_scale.shape[-1],
         "gqa_group_size": int(q.shape[1] // k.shape[1]),
         "q_head_dim": int(q.shape[2]),
         "block_seq": block_seq,
@@ -185,18 +169,16 @@ def get_static_key(q, k, block_seq):

 def get_run_key(q, max_len_in_batch):
     batch_size = q.shape[0]
-    return batch_size * 1024 * 1024 * 1024 + max_len_in_batch
+    return batch_size * 1000 * 1000 * 1000 + max_len_in_batch


-autotune(
-    kernel_name="_fwd_kernel_flash_decode_normal_stage1:v2",
+@autotune(
+    kernel_name="_fwd_kernel_flash_decode_normal_stage1:v3",
     configs_gen_func=get_test_configs,
     static_key_func=get_static_key,
     run_key_func=get_run_key,
     mutates_args=["mid_out", "mid_out_logsumexp"],
 )
-
-
 def flash_decode_stage1(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -288,3 +270,67 @@ def flash_decode_stage1(
         num_stages=num_stages,
     )
     return
+
+
+if __name__ == "__main__":
+    # static params
+    kv_quant_group_size = 8
+    gqa_group_size = 4
+    q_head_dim = 128
+    block_seq = 256
+    out_dtype = torch.bfloat16
+
+    batch_sizes = [1, 8, 16, 32, 64, 128]
+    decode_lengths = [1024, 2048, 8192, 16384]
+
+    q_head_num = gqa_group_size
+
+    import os
+
+    os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = "2"
+    Autotuner.start_autotune_warmup()
+    # autotuning kernel
+    for batch_size in batch_sizes:
+        for length in decode_lengths:
+            # Setup test tensors
+            q = torch.randn(batch_size, q_head_num, q_head_dim, dtype=out_dtype, device="cuda")
+            k = torch.ones(batch_size * length, 1, q_head_dim, dtype=torch.int8, device="cuda")
+            k_scale = torch.randn(
+                batch_size * length, 1, q_head_dim // kv_quant_group_size, dtype=torch.float32, device="cuda"
+            )
+            v = torch.ones(batch_size * length, 1, q_head_dim, dtype=torch.int8, device="cuda")
+            v_scale = torch.randn(
+                batch_size * length, 1, q_head_dim // kv_quant_group_size, dtype=torch.float32, device="cuda"
+            )
+            Req_to_tokens = torch.arange(0, batch_size * length, dtype=torch.int32, device="cuda").view(
+                batch_size, length
+            )
+            B_req_idx = torch.arange(batch_size, dtype=torch.int32, device="cuda")
+            B_seq_len = torch.full((batch_size,), length, dtype=torch.int32, device="cuda")
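+            # mid_out / mid_out_logsumexp below are sized with block_num partitions per request
+            # (this becomes mid_out.shape[2]); the batch-size-dependent values are warmup heuristics.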
+
+            if batch_size <= 16:
+                block_num = 128
+            elif batch_size <= 64:
+                block_num = 64
+            else:
+                block_num = 32
+
+            mid_out = torch.zeros(batch_size, q_head_num, block_num, q_head_dim, dtype=out_dtype, device="cuda")
+            mid_out_logsumexp = torch.zeros(batch_size, q_head_num, block_num, dtype=out_dtype, device="cuda")
+
+            flash_decode_stage1(
+                q=q,
+                k=k,
+                k_scale=k_scale,
+                v=v,
+                v_scale=v_scale,
+                Req_to_tokens=Req_to_tokens,
+                B_req_idx=B_req_idx,
+                B_seq_len=B_seq_len,
+                max_len_in_batch=length,
+                mid_out=mid_out,
+                mid_out_logsumexp=mid_out_logsumexp,
+                block_seq=block_seq,
+            )
+
+    Autotuner.end_autotune_warmup()
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/_fwd_kernel_flash_decode_normal_stage1:v3/{block_seq=256,gqa_group_size=4,kv_quant_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/_fwd_kernel_flash_decode_normal_stage1:v3/{block_seq=256,gqa_group_size=4,kv_quant_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..3f009dd21e
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/_fwd_kernel_flash_decode_normal_stage1:v3/{block_seq=256,gqa_group_size=4,kv_quant_group_size=8,out_dtype=torch.bfloat16,q_head_dim=128}_NVIDIA_H200.json
@@ -0,0 +1,122 @@
+{
+    "1000001024": {
+        "BLOCK_N": 128,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "1000002048": {
+        "BLOCK_N": 128,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "1000008192": {
+        "BLOCK_N": 128,
+        "num_stages": 2,
+        "num_warps": 16
+    },
+    "1000016384": {
+        "BLOCK_N": 128,
+        "num_stages": 4,
+        "num_warps": 8
+    },
+    "128000001024": {
+        "BLOCK_N": 32,
+        "num_stages": 6,
+        "num_warps": 4
+    },
+    "128000002048": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "128000008192": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "128000016384": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "16000001024": {
+        "BLOCK_N": 128,
+        "num_stages": 2,
+        "num_warps": 16
+    },
+    "16000002048": {
+        "BLOCK_N": 128,
+        "num_stages": 2,
+        "num_warps": 16
+    },
+    "16000008192": {
+        "BLOCK_N": 64,
+        "num_stages": 2,
+        "num_warps": 4
+    },
+    "16000016384": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "32000001024": {
+        "BLOCK_N": 128,
+        "num_stages": 2,
+        "num_warps": 16
+    },
+    "32000002048": {
+        "BLOCK_N": 128,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "32000008192": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "32000016384": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "64000001024": {
+        "BLOCK_N": 128,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "64000002048": {
+        "BLOCK_N": 32,
+        "num_stages": 6,
+        "num_warps": 4
+    },
+    "64000008192": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "64000016384": {
+        "BLOCK_N": 16,
+        "num_stages": 4,
+        "num_warps": 2
+    },
+    "8000001024": {
+        "BLOCK_N": 128,
+        "num_stages": 2,
+        "num_warps": 16
+    },
+    "8000002048": {
+        "BLOCK_N": 128,
+        "num_stages": 2,
+        "num_warps": 16
+    },
+    "8000008192": {
+        "BLOCK_N": 128,
+        "num_stages": 4,
+        "num_warps": 4
+    },
+    "8000016384": {
+        "BLOCK_N": 64,
+        "num_stages": 4,
+        "num_warps": 2
+    }
+}
\ No newline at end of file