diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index 5c1d2b8712..2463bbb8e8 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -11,6 +11,7 @@ from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.common.kv_cache_mem_manager import MemoryManager from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class from lightllm.common.req_manager import ReqManager @@ -53,6 +54,9 @@ class TpPartBaseModel: # infer state class infer_state_class = InferStateInfo + def get_radix_class(self): + return RadixCache + def __init__(self, kvargs): self.args = get_env_start_args() self.run_mode = kvargs["run_mode"] diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 9153349c5d..646f998642 100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: raise Exception("need to impl") - def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._get_qkv(input1, infer_state, layer_weight) - input1 = None + def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._get_o(o, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) ffn_out = self._ffn(input1, infer_state, layer_weight) - input1 = None if self.tp_world_size_ > 1: all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) - o = self._context_attention_wrapper_run( q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight ) - q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): - input1 = self._att_norm(input_embdings, infer_state, layer_weight) - q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight) - input1 = None + def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight) self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None o = self._tpsp_get_o(o, infer_state, layer_weight) + return o + + def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index edf7fe21b9..21b5b7959e 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -7,7 +7,16 @@ QKVROWNMMWeight, COLMMWeight, ) -from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight +from .norm_weight import ( + TpRMSNormWeight, + RMSNormWeight, + GEMMANormWeight, + LayerNormWeight, + NoTpGEMMANormWeight, + QKRMSNORMWeight, + QKGEMMANormWeight, +) from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight from .att_sink_weight import TpAttSinkWeight from .fused_moe.fused_moe_weight import FusedMoeWeight +from .parameter_weight import ParameterWeight, TpParameterWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py index 20416606f9..89a3d24119 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py @@ -71,6 +71,14 @@ def __call__( return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func) +class GEMMANormWeight(RMSNormWeight): + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.weight_name in weights: + self.weight.copy_(weights[self.weight_name]) + self.weight += 1 + self.weight.load_ok = True + + class LayerNormWeight(BaseWeightTpl, PlatformAwareOp): def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None): super().__init__(tp_rank=0, tp_world_size=1) @@ -276,3 +284,23 @@ def __call__( eps: float, ) -> None: return self._forward(q=q, k=k, eps=eps) + + +class QKGEMMANormWeight(QKRMSNORMWeight): + def load_hf_weights(self, weights: Dict[str, torch.Tensor]): + if self.q_weight_name in weights: + self.q_weight.copy_(weights[self.q_weight_name]) + self.q_weight += 1 + self.q_weight.load_ok = True + if self.k_weight_name in weights: + self.k_weight.copy_(weights[self.k_weight_name]) + self.k_weight += 1 + self.k_weight.load_ok = True + + def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple: + assert q.ndim == 2 and self.q_weight.ndim == 1 + assert k.ndim == 2 and self.k_weight.ndim == 1 + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + # So we need to set fp32_multiply to True here. + return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps, fp32_multiply=True) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py new file mode 100644 index 0000000000..0afb0ecab2 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -0,0 +1,83 @@ +import torch +from typing import Dict, Optional, Tuple +from .base_weight import BaseWeightTpl + + +class ParameterWeight(BaseWeightTpl): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + weight_shape: Optional[Tuple[int, ...]], + bias_name: Optional[str] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + super().__init__() + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.weight: Optional[torch.Tensor] = None + self.bias: Optional[torch.Tensor] = None + if weight_shape is not None: + self._create_weight() + + def _create_weight(self): + if self.weight_shape is not None: + self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_) + self.weight.load_ok = False + if self.bias_name is not None and self.bias_shape is not None: + self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_) + self.bias.load_ok = False + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_name in weights: + t_weight = weights[self.weight_name] + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name] + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True + + def verify_load(self) -> bool: + if self.weight is not None and not getattr(self.weight, "load_ok", False): + return False + if self.bias is not None and not getattr(self.bias, "load_ok", False): + return False + return True + + +class TpParameterWeight(ParameterWeight): + def __init__( + self, + weight_name: str, + data_type: torch.dtype, + split_n_embed: int, + bias_name: Optional[str] = None, + weight_shape: Optional[Tuple[int, ...]] = None, + bias_shape: Optional[Tuple[int, ...]] = None, + ): + self.split_n_embed = split_n_embed + # Calculate TP-split shapes if full shapes are provided + tp_weight_shape = None + tp_bias_shape = None + if weight_shape is not None: + tp_weight_shape = (split_n_embed,) + weight_shape[1:] + if bias_shape is not None: + tp_bias_shape = (split_n_embed,) + bias_shape[1:] + super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape) + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) + + if self.weight_name in weights: + t_weight = weights[self.weight_name][start:end] + self.weight.copy_(t_weight.to(self.data_type_)) + self.weight.load_ok = True + if self.bias_name is not None and self.bias_name in weights: + t_bias = weights[self.bias_name][start:end] + self.bias.copy_(t_bias.to(self.data_type_)) + self.bias.load_ok = True diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py new file mode 100644 index 0000000000..bd2aaed530 --- /dev/null +++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py @@ -0,0 +1,247 @@ +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _copy_buffer_kernel( + src_ptr, + dst_ptr, + src_idx_ptr, + dst_idx_ptr, + stride_layer, + stride_slot, + d_size, + BLOCK_D: tl.constexpr, +): + pair_idx = tl.program_id(0) + layer_idx = tl.program_id(1) + block_d = tl.program_id(2) + + stride_layer = stride_layer.to(tl.int64) + stride_slot = stride_slot.to(tl.int64) + + src_slot = tl.load(src_idx_ptr + pair_idx).to(tl.int64) + dst_slot = tl.load(dst_idx_ptr + pair_idx).to(tl.int64) + + offs = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offs < d_size + + base = layer_idx * stride_layer + tl.store( + dst_ptr + base + dst_slot * stride_slot + offs, + tl.load(src_ptr + base + src_slot * stride_slot + offs, mask=mask), + mask=mask, + ) + + +@triton.jit +def _fork_buffer_kernel( + src_ptr, + dst_ptr, + src_idx_ptr, + dst_idx_ptr, + stride_layer, + stride_slot, + d_size, + num_dst_per_src, + BLOCK_D: tl.constexpr, +): + flat_pair = tl.program_id(0) + layer_idx = tl.program_id(1) + block_d = tl.program_id(2) + + src_chunk = flat_pair // num_dst_per_src + + stride_layer = stride_layer.to(tl.int64) + stride_slot = stride_slot.to(tl.int64) + + src_slot = tl.load(src_idx_ptr + src_chunk).to(tl.int64) + dst_slot = tl.load(dst_idx_ptr + flat_pair).to(tl.int64) + + offs = block_d * BLOCK_D + tl.arange(0, BLOCK_D) + mask = offs < d_size + + base = layer_idx * stride_layer + tl.store( + dst_ptr + base + dst_slot * stride_slot + offs, + tl.load(src_ptr + base + src_slot * stride_slot + offs, mask=mask), + mask=mask, + ) + + +def _get_buffer_copy_configs(): + configs = [] + for block_d in [128, 256, 512, 1024, 2048, 4096]: + for num_warps in [1, 2, 4, 8]: + for num_stages in [1, 2]: + configs.append({"BLOCK_D": block_d, "num_warps": num_warps, "num_stages": num_stages}) + return configs + + +def _get_copy_static_key( + src_buffer: torch.Tensor, +): + d_size = ( + src_buffer.shape[2] + if src_buffer.ndim == 3 + else src_buffer.numel() // (src_buffer.shape[0] * src_buffer.shape[1]) + ) + return { + "dtype": str(src_buffer.dtype), + "d_size": d_size, + "layer_num": src_buffer.shape[0], + "ndim": src_buffer.ndim, + } + + +def _get_copy_run_key(src_buffer: torch.Tensor): + return 0 + + +def _get_fork_static_key(src_buffer: torch.Tensor): + d_size = ( + src_buffer.shape[2] + if src_buffer.ndim == 3 + else src_buffer.numel() // (src_buffer.shape[0] * src_buffer.shape[1]) + ) + return { + "dtype": str(src_buffer.dtype), + "d_size": d_size, + "layer_num": src_buffer.shape[0], + "ndim": src_buffer.ndim, + } + + +def _get_fork_run_key(src_buffer: torch.Tensor): + return 0 + + +def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor: + if buffer.ndim == 3: + return buffer + L, B = buffer.shape[:2] + return buffer.view(L, B, -1) + + +@autotune( + kernel_name="mamba_buffer_copy_1d:v1", + configs_gen_func=_get_buffer_copy_configs, + static_key_func=_get_copy_static_key, + run_key_func=_get_copy_run_key, +) +def _copy_mamba_buffer_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, + run_config: dict = None, +): + if not run_config: + d_size = src_buffer.shape[2] + BLOCK_D = min(4096, triton.next_power_of_2(d_size)) + num_warps = 4 if BLOCK_D >= 1024 else 2 + run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1} + + config = run_config + BLOCK_D = config["BLOCK_D"] + num_pairs = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + + grid = (num_pairs, layer_num, num_blocks_d) + _copy_buffer_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes, + src_buffer.stride(0), + src_buffer.stride(1), + d_size, + BLOCK_D=BLOCK_D, + num_warps=config["num_warps"], + num_stages=config["num_stages"], + ) + + +@autotune( + kernel_name="mamba_buffer_fork_1d:v1", + configs_gen_func=_get_buffer_copy_configs, + static_key_func=_get_fork_static_key, + run_key_func=_get_fork_run_key, +) +def _fork_mamba_buffer_autotuned( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes_flat: torch.Tensor, + num_dst_per_src: int, + run_config: dict = None, +): + if not run_config: + d_size = src_buffer.shape[2] + BLOCK_D = min(4096, triton.next_power_of_2(d_size)) + num_warps = 4 if BLOCK_D >= 1024 else 2 + run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1} + + config = run_config + BLOCK_D = config["BLOCK_D"] + num_src = src_indexes.shape[0] + layer_num = src_buffer.shape[0] + d_size = src_buffer.shape[2] + + num_blocks_d = triton.cdiv(d_size, BLOCK_D) + total_pairs = num_src * num_dst_per_src + + grid = (total_pairs, layer_num, num_blocks_d) + _fork_buffer_kernel[grid]( + src_buffer, + dst_buffer, + src_indexes, + dst_indexes_flat, + src_buffer.stride(0), + src_buffer.stride(1), + d_size, + num_dst_per_src, + BLOCK_D=BLOCK_D, + num_warps=config["num_warps"], + num_stages=config["num_stages"], + ) + + +def copy_mamba_buffer( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + assert src_buffer.shape == dst_buffer.shape + assert src_indexes.shape == dst_indexes.shape and src_indexes.ndim == 1 + + src_flat = _flatten_trailing_dims(src_buffer) + dst_flat = _flatten_trailing_dims(dst_buffer) + _copy_mamba_buffer_autotuned(src_flat, dst_flat, src_indexes, dst_indexes) + + +def fork_mamba_buffer( + src_buffer: torch.Tensor, + dst_buffer: torch.Tensor, + src_indexes: torch.Tensor, + dst_indexes: torch.Tensor, +): + assert src_buffer.shape == dst_buffer.shape + assert src_indexes.ndim == 1 + assert dst_indexes.ndim == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}" + assert ( + dst_indexes.shape[0] == src_indexes.shape[0] + ), f"Mismatch: src_indexes {src_indexes.shape[0]} vs dst_indexes rows {dst_indexes.shape[0]}" + + num_dst_per_src = dst_indexes.shape[1] + dst_indexes_flat = dst_indexes.reshape(-1).contiguous() + + src_flat = _flatten_trailing_dims(src_buffer) + dst_flat = _flatten_trailing_dims(dst_buffer) + _fork_mamba_buffer_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat, num_dst_per_src) diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py index fab0141158..e152a8dd83 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py @@ -78,11 +78,12 @@ def _qk_rms_norm_fused_kernel( WK_ptr, stride_k_row, stride_k_col, + eps, # Dimensions num_heads_q: tl.constexpr, # Q 的头数 (用于判断边界) head_dim: tl.constexpr, - eps, BLOCK_SIZE: tl.constexpr, + FP32_MULTIPLY: tl.constexpr, ): # PID 0: 处理第几个 Token (Row) row_idx = tl.program_id(0) @@ -108,13 +109,15 @@ def _qk_rms_norm_fused_kernel( # RMSNorm 计算 var = tl.sum(x * x, axis=0) / head_dim rstd = 1 / tl.sqrt(var + eps) + x *= rstd # 加载 Q 的权重 (假设所有 Head 共享同一组 dim=head_dim 的权重) w = tl.load(WQ_ptr + offs) - - x *= rstd - y = x.to(w.dtype) * w - + if FP32_MULTIPLY: + w = w.to(tl.float32) + else: + x = x.to(Q_ptr.dtype.element_ty) + y = (x * w).to(Q_ptr.dtype.element_ty) # 写回 Q tl.store(Q_ptr + q_ptr_offset, y) @@ -132,18 +135,27 @@ def _qk_rms_norm_fused_kernel( # RMSNorm 计算 var = tl.sum(x * x, axis=0) / head_dim rstd = 1 / tl.sqrt(var + eps) + x *= rstd # 加载 K 的权重 w = tl.load(WK_ptr + offs) - x *= rstd - - y = x.to(w.dtype) * w - + if FP32_MULTIPLY: + w = w.to(tl.float32) + else: + x = x.to(K_ptr.dtype.element_ty) + y = (x * w).to(K_ptr.dtype.element_ty) # 写回 K tl.store(K_ptr + k_ptr_offset, y) -def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6): +def qk_rmsnorm_fused_forward( + q: torch.Tensor, + k: torch.Tensor, + w_q: torch.Tensor, + w_k: torch.Tensor, + eps: float = 1e-6, + fp32_multiply: bool = False, +): """ In-place RMSNorm for both Q and K in a single kernel launch. Supports GQA (different number of heads for Q and K). @@ -197,6 +209,7 @@ def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor num_heads_q=num_heads_q, head_dim=head_dim, eps=eps, + FP32_MULTIPLY=fp32_multiply, BLOCK_SIZE=BLOCK_SIZE, num_warps=4, ) diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py new file mode 100644 index 0000000000..9d2d372e17 --- /dev/null +++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py @@ -0,0 +1,207 @@ +from typing import List, Tuple, Union + +import torch +import numpy as np + +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer +from lightllm.utils.log_utils import init_logger +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt + +logger = init_logger(__name__) + +MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num" + + +class LayerCache: + def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int): + self.size = size + self.dtype = dtype + self.shape = shape + self.layer_num = layer_num + + self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda") + + def get_cell_size(self): + return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) + + +class MambaCacheManager: + def __init__( + self, + size: int, + layer_num: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + ): + # init the mem state + self.size = size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._mem_state_return = torch.arange( + 0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self._return_start = 0 + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num = SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}") + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.HOLD_TOKEN_MEMINDEX = self.size + + # init the layer cache + self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num) + self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num) + self.HOLD_BUFFER_INDEX = size + + logger.warning( + f"Linear attention state cache size: {size}\n" + f"Conv state use : " + f"{self.conv_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n" + f"Ssm state use : " + f"{self.ssm_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n" + ) + + def get_mamba_cache(self, layer_idx: int): + conv_state = self.conv_state_cache.buffer[layer_idx] + ssm_state = self.ssm_state_cache.buffer[layer_idx] + return conv_state, ssm_state + + def copy_state_buffers(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor): + copy_mamba_buffer( + self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes + ) + copy_mamba_buffer( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes + ) + + def fork_state_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + fork_mamba_buffer( + self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) + fork_mamba_buffer( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) + + def fork_ssm_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor): + """ + Fork ONLY SSM states (not conv states) from source indices to destination indices. + + This is used for MTP mode where each buffer maintains its own independent conv state, + but SSM states need to be synchronized. + """ + fork_mamba_buffer( + self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes + ) + + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" + + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """ + Free the allocated cache buffers and clear them. + + Args: + free_index: Buffer indices to free (tensor or list of ints) + """ + # Convert to tensor if needed for indexing + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda") + else: + free_index_tensor = free_index.to(device="cuda", dtype=torch.long) + + # Clear the buffers for the freed indices + # Shape: [layer_num, buffer_index, *shape] + self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0 + self.ssm_state_cache.buffer[:, free_index_tensor, ...] = 0 + + # update the mem state + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device) + self.mem_state[start:end] = free_index_tensor + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + + return + + def free_all(self): + self.conv_state_cache.buffer.fill_(0) + self.ssm_state_cache.buffer.fill_(0) + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) + + return + + def resize_mem(self, new_size): + """ + just for test code + """ + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return + + +class ReadOnlyStaticsMambaCacheManager: + """ + 读取一些统计信息 + """ + + def __init__(self) -> None: + args = get_env_start_args() + self.global_world_size = args.tp + self.node_world_size = args.tp // args.nnodes + self.dp_world_size = self.global_world_size // args.dp + # 兼容多机 dp size=1 纯 tp 模式的情况 + self.is_multinode_tp = args.dp == 1 and args.nnodes > 1 + self.shared_tp_can_use_token_nums = [ + SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}") + for rank_in_node in range(0, self.node_world_size, self.dp_world_size) + ] + + def get_unrefed_token_num(self, dp_rank_in_node: int): + if self.is_multinode_tp: + return self.shared_tp_can_use_token_nums[0].get_value() + return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value() diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 33bdca4475..bbe2bb4a3b 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -3,6 +3,7 @@ from lightllm.utils.log_utils import init_logger from .kv_cache_mem_manager import MemoryManager from typing import List, Optional + from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args @@ -67,6 +68,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num) self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + self.req_to_buffer_index = None def alloc(self): return self.req_list.alloc() @@ -93,6 +95,11 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + @property + def has_recurrent_state(self): + """Whether this model uses per-request recurrent state buffers (e.g. Mamba/linear attention).""" + return self.req_to_buffer_index is not None + class ReqSamplingParamsManager: """ @@ -232,3 +239,28 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List): p_token_counts_tensor.cuda(non_blocking=True), p_cumsum_seq_len_tensor.cuda(non_blocking=True), ) + + +class ReqManagerForMamba(ReqManager): + def __init__(self, max_request_num, max_sequence_length, mem_manager): + from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager + + super().__init__(max_request_num, max_sequence_length, mem_manager) + self.mtp_step = get_env_start_args().mtp_step + self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager + self.req_to_buffer_index = torch.zeros( + (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda" + ) + self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX + + def free_buffer(self, free_buffer_indexes: List[int]): + self.buffer_mem_manager.free(free_buffer_indexes) + return + + def alloc_buffer_for_req(self, req_index: torch.Tensor): + num_reqs = req_index.shape[0] + num_buffers_per_req = self.mtp_step + 1 + buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req) + if not buffer_indexes.is_cuda: + buffer_indexes = buffer_indexes.cuda() + self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..cc5c68eb79 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..b6e5109b62 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..511935b4cf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 128, + "num_stages": 4, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 128, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..cc5c68eb79 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1038611f6a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..4bc06d07d9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..7421097fa4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..f1159e4357 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..892c20e78d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..d831f32c4a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..40cdc996b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,12 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..833062ec2f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 2 + }, + "100": { + "num_warps": 1 + }, + "1024": { + "num_warps": 1 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 2 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 1 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 8 + }, + "8": { + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5f2cf9465b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 2 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 2 + }, + "2048": { + "num_warps": 2 + }, + "256": { + "num_warps": 2 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 2 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..c8a1841674 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 2 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 4 + }, + "32": { + "num_warps": 2 + }, + "4096": { + "num_warps": 8 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..a97cabf8b2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "4": { + "BK": 64, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..786624883f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..eaca03cf75 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,12 @@ +{ + "2": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..d9064e5d6a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..baef19d90c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 32, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "64": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..90ac24c408 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 1 + }, + "1024": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 4, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..31d7a6e203 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,118 @@ +{ + "1024": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "12": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "1200": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "12288": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 256, + "num_warps": 8 + }, + "131072": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "1536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "1600": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "192": { + "BLOCK_N": 512, + "num_warps": 1 + }, + "196608": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "24576": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "262144": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "3072": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "384": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "4096": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "49152": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "512": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "64": { + "BLOCK_N": 256, + "num_warps": 8 + }, + "65536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "768": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "8": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "96": { + "BLOCK_N": 512, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0042ef8a2a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..54a5967071 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..bb78d1dd84 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..1552d8bf1a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..08cbfd85c3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "512": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "800": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..13a070b8f0 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..169a148799 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 512, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 128, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..5022588ef5 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 2, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 8 + }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..4ae96d02d1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 32, + "NUM_STAGE": 2, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 2, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..28c654f3d2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 3, + "num_warps": 2 + }, + "100": { + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "num_stages": 3, + "num_warps": 2 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 2, + "num_warps": 1 + }, + "2048": { + "num_stages": 4, + "num_warps": 1 + }, + "256": { + "num_stages": 3, + "num_warps": 1 + }, + "32": { + "num_stages": 3, + "num_warps": 2 + }, + "4096": { + "num_stages": 4, + "num_warps": 1 + }, + "64": { + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "num_stages": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..08b0d5e5bc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "num_stages": 3, + "num_warps": 1 + }, + "1024": { + "num_stages": 2, + "num_warps": 1 + }, + "128": { + "num_stages": 4, + "num_warps": 4 + }, + "16": { + "num_stages": 4, + "num_warps": 4 + }, + "16384": { + "num_stages": 2, + "num_warps": 4 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 5, + "num_warps": 2 + }, + "32": { + "num_stages": 5, + "num_warps": 8 + }, + "4096": { + "num_stages": 2, + "num_warps": 1 + }, + "64": { + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "num_stages": 5, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..0d871841ed --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "num_stages": 1, + "num_warps": 1 + }, + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "128": { + "num_stages": 2, + "num_warps": 1 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 4, + "num_warps": 1 + }, + "2048": { + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "num_stages": 4, + "num_warps": 1 + }, + "32": { + "num_stages": 2, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 1 + }, + "64": { + "num_stages": 4, + "num_warps": 1 + }, + "8": { + "num_stages": 3, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..9f3a8dcb25 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json new file mode 100644 index 0000000000..72026f01c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json @@ -0,0 +1,74 @@ +{ + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 64, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "32768": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "512": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "800": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..5d9216c2ea --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 128, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..338af08a1d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 2 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..4bc06d07d9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..f1159e4357 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..131da59770 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "8": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..c8fa422e0c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,14 @@ +{ + "2": { + "BK": 128, + "BV": 64, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..2af1b86e90 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..40cdc996b9 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,12 @@ +{ + "2": { + "BV": 32, + "num_stages": 3, + "num_warps": 4 + }, + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..a40eda35d4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 8 + }, + "100": { + "num_warps": 1 + }, + "1024": { + "num_warps": 8 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 1 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 1 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..808ed9a7fc --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 1 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 2 + }, + "128": { + "num_warps": 8 + }, + "16": { + "num_warps": 8 + }, + "16384": { + "num_warps": 1 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 8 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b08208be2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 1 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 1 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 1 + }, + "16384": { + "num_warps": 2 + }, + "2048": { + "num_warps": 8 + }, + "256": { + "num_warps": 8 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 8 + }, + "64": { + "num_warps": 1 + }, + "8": { + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..27e4804a61 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 64, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..fb62cf8259 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "8": { + "BK": 32, + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 0000000000..7749b3601f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,12 @@ +{ + "2": { + "BK": 64, + "num_stages": 2, + "num_warps": 4 + }, + "4": { + "BK": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..49c4dc63d1 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..ad8d397d3b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 8, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 32, + "num_warps": 1 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "16384": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "2048": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 16, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..907575d960 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "16": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "16384": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "32": { + "BLK_HEADS": 64, + "num_warps": 4 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "64": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..55ccb24a65 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,118 @@ +{ + "1024": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "128": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "131072": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "1536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16": { + "BLOCK_N": 256, + "num_warps": 4 + }, + "1600": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "192": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "24": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "2400": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "24576": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 512, + "num_warps": 2 + }, + "262144": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "3072": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "32768": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "384": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "393216": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "4096": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "49152": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "512": { + "BLOCK_N": 256, + "num_warps": 4 + }, + "6144": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "64": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "65536": { + "BLOCK_N": 128, + "num_warps": 1 + }, + "768": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "8": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 2 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + }, + "98304": { + "BLOCK_N": 128, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..e5a383f23f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..56c79e3a43 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..4843ed8ccf --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..662875ecdb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..3c0e605b00 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..1f8134fa64 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "131072": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32768": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8192": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..d82ca44a21 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "163840": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "20480": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 8 + }, + "320": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "40960": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 0000000000..96eabffc42 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "2048": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..69a0e9ca42 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 2048, + "num_stages": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..9de6716c3c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 4096, + "num_stages": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..2e3a3febbb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..5c9f40590b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 1024, + "num_stages": 1, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..0a3facae38 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 2048, + "num_stages": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..9cdaab5ace --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 4096, + "num_stages": 2, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..2e3a3febbb --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 512, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json new file mode 100644 index 0000000000..889f6ab71b --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "0": { + "BLOCK_D": 1024, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..07e5e6875f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "100": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "16384": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "2048": { + "BLOCK_SIZE": 128, + "num_warps": 2 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 512, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..ff4632955f --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 256, + "BLOCK_M": 4, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "1024": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 16 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 128, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 2 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 64, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 0000000000..89ab51ff8c --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "1": { + "BLOCK_DIM": 128, + "BLOCK_M": 4, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "100": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "128": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "16384": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "2048": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "256": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "32": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "4096": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 1 + }, + "64": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..f4d29554da --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 1, + "num_warps": 8 + }, + "100": { + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "num_stages": 5, + "num_warps": 2 + }, + "16384": { + "num_stages": 1, + "num_warps": 2 + }, + "2048": { + "num_stages": 2, + "num_warps": 2 + }, + "256": { + "num_stages": 3, + "num_warps": 4 + }, + "32": { + "num_stages": 5, + "num_warps": 2 + }, + "4096": { + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "num_stages": 4, + "num_warps": 4 + }, + "8": { + "num_stages": 5, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..8605a91680 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 2, + "num_warps": 1 + }, + "100": { + "num_stages": 4, + "num_warps": 4 + }, + "1024": { + "num_stages": 3, + "num_warps": 2 + }, + "128": { + "num_stages": 2, + "num_warps": 1 + }, + "16": { + "num_stages": 3, + "num_warps": 4 + }, + "16384": { + "num_stages": 3, + "num_warps": 2 + }, + "2048": { + "num_stages": 3, + "num_warps": 2 + }, + "256": { + "num_stages": 5, + "num_warps": 2 + }, + "32": { + "num_stages": 2, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 2 + }, + "64": { + "num_stages": 2, + "num_warps": 1 + }, + "8": { + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..5b3e656b6d --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,18 @@ +{ + "1024": { + "num_stages": 5, + "num_warps": 1 + }, + "16384": { + "num_stages": 4, + "num_warps": 2 + }, + "2048": { + "num_stages": 2, + "num_warps": 2 + }, + "4096": { + "num_stages": 3, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..ada783ef92 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,50 @@ +{ + "1": { + "num_stages": 3, + "num_warps": 4 + }, + "100": { + "num_stages": 2, + "num_warps": 1 + }, + "1024": { + "num_stages": 2, + "num_warps": 1 + }, + "128": { + "num_stages": 5, + "num_warps": 2 + }, + "16": { + "num_stages": 3, + "num_warps": 2 + }, + "16384": { + "num_stages": 3, + "num_warps": 1 + }, + "2048": { + "num_stages": 5, + "num_warps": 2 + }, + "256": { + "num_stages": 3, + "num_warps": 2 + }, + "32": { + "num_stages": 3, + "num_warps": 1 + }, + "4096": { + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "num_stages": 3, + "num_warps": 4 + }, + "8": { + "num_stages": 3, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..12993b0231 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,74 @@ +{ + "10": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 0000000000..0a0f01fe7a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,98 @@ +{ + "10": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "131072": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "16384": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "163840": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "20480": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "32768": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 2, + "num_warps": 1 + }, + "8192": { + "BLOCK_M": 8, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index f2e29d4a88..2caee91709 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -7,6 +7,7 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel from lightllm.models.internlm2.model import Internlm2TpPartModel @@ -39,4 +40,6 @@ ) from lightllm.models.gpt_oss.model import GptOssTpPartModel from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.models.qwen3_5_moe.model import Qwen3_5MOETpPartModel from .registry import get_model, get_model_class diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py index 0566c9f1c6..e42a6191e6 100644 --- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py @@ -31,7 +31,7 @@ def _parse_config(self): head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"] self.head_dim = self.network_config_.get("head_dim", head_dim) self.n_embed = self.network_config_["hidden_size"] - self.n_inter = self.network_config_["intermediate_size"] + self.n_inter = self.network_config_.get("intermediate_size", -1) def _init_weight_names(self): self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py index 7156a5ce23..825a985b46 100644 --- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py +++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py @@ -227,14 +227,14 @@ def _init_datatype(self): def rot_pos_emb(self, grid_thw): pos_ids = [] s = self.spatial_merge_size - for _, h, w in grid_thw: + for t, h, w in grid_thw: pos_shape = (h // s, s, w // s, s) hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w) wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1) hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten() - pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1)) + pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1)) pos_ids = torch.cat(pos_ids, dim=0) max_grid_size = grid_thw[:, 1:].max() cos_full, sin_full = self.rotary_pos_emb(max_grid_size) diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index a29cb8758b..6076756043 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -57,6 +57,8 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False) + # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup) + self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.view( diff --git a/lightllm/models/qwen2_vl/triton_kernel/mrope.py b/lightllm/models/qwen2_vl/triton_kernel/mrope.py index 5aed658626..e488a0ce30 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/mrope.py +++ b/lightllm/models/qwen2_vl/triton_kernel/mrope.py @@ -193,10 +193,11 @@ def mrope_triton_fused( sin: torch.Tensor, mrope_section: torch.Tensor, is_interleaved: bool, + partial_rotary_factor: float = 1.0, run_config: Optional[dict] = None, ): head_num_q, head_num_k = q.shape[1], k.shape[1] - head_dim = int(q.shape[2]) + head_dim = int(q.shape[2] * partial_rotary_factor) num_tokens = q.shape[0] if not run_config: diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py index f2cd38ec8e..bc313fe467 100644 --- a/lightllm/models/qwen2_vl/vision_process.py +++ b/lightllm/models/qwen2_vl/vision_process.py @@ -187,7 +187,10 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc if image.mode != "RGB": image = image.convert("RGB") image_arr = np.asarray(image, dtype=np.uint8) - image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + # Copy to ensure writable array (avoids PyTorch warning for read-only NumPy arrays) + image_data = ( + torch.from_numpy(image_arr.copy()).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True) + ) grouped_images, grouped_images_index = group_images_by_shape( [image_data], disable_grouping=self.disable_grouping diff --git a/lightllm/models/qwen3_5/__init__.py b/lightllm/models/qwen3_5/__init__.py new file mode 100644 index 0000000000..56a41a228a --- /dev/null +++ b/lightllm/models/qwen3_5/__init__.py @@ -0,0 +1,16 @@ +""" +Qwen3.5 Multimodal Model Module (Dense Variant) + +Provides Qwen3.5 dense multimodal model with hybrid attention and vision-language support. +For MoE variant, see qwen3_5_moe module. +""" + +from .model import ( + Qwen3_5TpPartModel, + QWen3_5Tokenizer, +) + +__all__ = [ + "Qwen3_5TpPartModel", + "QWen3_5Tokenizer", +] diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py new file mode 100644 index 0000000000..d837c4d291 --- /dev/null +++ b/lightllm/models/qwen3_5/infer_struct.py @@ -0,0 +1,17 @@ +import torch +from typing import List + +from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen35InferStateInfo(Qwen2VLInferStateInfo): + def __init__(self): + super().__init__() + self.gate_value = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + return diff --git a/lightllm/models/qwen3_5/layer_infer/__init__.py b/lightllm/models/qwen3_5/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..d0657bcbe8 --- /dev/null +++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py @@ -0,0 +1,58 @@ +import torch +import torch.distributed as dist +from typing import Tuple + +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( + Qwen3NextTransformerLayerInfer, +) +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, +) +from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen35TransformerLayerInfer(Qwen3NextTransformerLayerInfer): + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + # Initialize mrope section from config + rope_scaling = network_config.get("rope_scaling", {}) + mrope_section = rope_scaling.get("mrope_section", [11, 11, 10]) + self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda") + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen35TransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + + qkv_out = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv_out.split( + [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1 + ) + o_gate = layer_weight._o_gate_proj.mm(input) + + # In-place sigmoid for gate + infer_state.gate_value = o_gate.sigmoid_() + layer_weight.qk_norm_weight_( + q, + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + eps=self.eps_, + ) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + mrope_triton_fused( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + self.mrope_section, + is_interleaved=True, # Qwen3 uses interleaved mrope + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv diff --git a/lightllm/models/qwen3_5/layer_weights/__init__.py b/lightllm/models/qwen3_5/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..da93133444 --- /dev/null +++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py @@ -0,0 +1,92 @@ +import torch + +from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextTransformerLayerWeight, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class Qwen35TransformerLayerWeight(Qwen3NextTransformerLayerWeight): + def _init_weight_names(self): + super()._init_weight_names() + self._gate_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" + self._gate_bias_name = None + self._up_weight_name = f"model.layers.{self.layer_num_}.mlp.up_proj.weight" + self._up_bias_name = None + self._gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight" + self._gate_up_bias_name = None + self._down_weight_name = f"model.layers.{self.layer_num_}.mlp.down_proj.weight" + self._down_bias_name = None + + def _init_gdn_weight(self): + # Initialize everything from parent first, then override only linear_in_proj. + super()._init_gdn_weight() + + prefix = f"model.layers.{self.layer_num_}.linear_attn" + hidden_size = self.network_config_["hidden_size"] + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + + # NOTE: keep grouped layout directly (q, k, v, z, b, a). + self.linear_in_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[ + qk_dim, + qk_dim, + v_dim, + v_dim, + self.linear_num_v_heads, + self.linear_num_v_heads, + ], + weight_names=[ + f"{prefix}.in_proj_q.weight", + f"{prefix}.in_proj_k.weight", + f"{prefix}.in_proj_v.weight", + f"{prefix}.in_proj_z.weight", + f"{prefix}.in_proj_b.weight", + f"{prefix}.in_proj_a.weight", + ], + data_type=self.data_type_, + quant_method=self.get_quant_method("in_proj_weight"), + ) + + def _preprocess_weight(self, weights): + # Keep parent conv1d preprocessing path. + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + + if linear_conv1d_weight_name in weights: + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + + self._split_linear_in_proj_qkv(weights) + + def _split_linear_in_proj_qkv(self, weights): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + qkv_name = f"{prefix}.in_proj_qkv.weight" + if qkv_name not in weights: + return + + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + expected_rows = 2 * qk_dim + v_dim + + qkv = weights[qkv_name] + if qkv.shape[0] != expected_rows: + logger.warning( + f"Layer {self.layer_num_}: unexpected in_proj_qkv shape " + f"{tuple(qkv.shape)}, expected first dim {expected_rows}; skip split" + ) + return + + q, k, v = torch.split(qkv, [qk_dim, qk_dim, v_dim], dim=0) + weights[f"{prefix}.in_proj_q.weight"] = q + weights[f"{prefix}.in_proj_k.weight"] = k + weights[f"{prefix}.in_proj_v.weight"] = v + del weights[qkv_name] diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py new file mode 100644 index 0000000000..63503c77ba --- /dev/null +++ b/lightllm/models/qwen3_5/model.py @@ -0,0 +1,100 @@ +import os +import json +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import ( + Qwen35TransformerLayerWeight, +) +from lightllm.models.qwen3_vl.model import QWen3VLTokenizer +from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import ( + Qwen3VLMultimodalPreLayerInfer, +) +from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import ( + Qwen3VLPreAndPostLayerWeight, +) +from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import ( + Qwen35TransformerLayerInfer, +) +from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo +from lightllm.common.build_utils import repair_config +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class QWen3_5Tokenizer(QWen3VLTokenizer): + """ + Tokenizer for Qwen3.5 multimodal model. + + Inherits all multimodal tokenization logic from Qwen3VL, + including image and video token handling. + """ + + def __init__(self, tokenizer=None, image_processor=None, **kwargs): + super().__init__(tokenizer, image_processor, **kwargs) + + +@ModelRegistry(["qwen3_5"], is_multimodal=True) +class Qwen3_5TpPartModel(Qwen3NextTpPartModel): + """ + Qwen3.5 Multimodal Model (Dense Variant) + + This model combines: + - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention) + - Multimodal capabilities from Qwen3VL (image/video processing) + - Dense MLP layers (non-MoE) + + Architecture: + - Every Nth layer uses full attention (config: full_attention_interval) + - Other layers use linear attention (Gated Delta Networks) + - Vision encoder processes images/videos before text model + - Multimodal embeddings merged with text embeddings + """ + + transformer_weight_class = Qwen35TransformerLayerWeight + pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight + + pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer + transformer_layer_infer_class = Qwen35TransformerLayerInfer + + infer_state_class = Qwen35InferStateInfo + + def _init_config(self): + config_path = os.path.join(self.weight_dir_, "config.json") + + with open(config_path, "r") as json_file: + all_config = json.load(json_file) + + self.config = all_config["text_config"] + self.vision_config = all_config.get("vision_config", None) + + if self.vision_config is None: + logger.warning("No vision_config found in checkpoint. " "Multimodal features may not work correctly.") + + # Apply standard config repairs + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + + # Qwen3.5 stores RoPE config under text_config.rope_parameters. + rope_parameters = self.config.get("rope_parameters") + if isinstance(rope_parameters, dict): + if "rope_theta" in rope_parameters and "rope_theta" not in self.config: + self.config["rope_theta"] = rope_parameters["rope_theta"] + if "partial_rotary_factor" in rope_parameters and "partial_rotary_factor" not in self.config: + self.config["partial_rotary_factor"] = rope_parameters["partial_rotary_factor"] + # Preserve the richer RoPE metadata in the expected field when absent. + if "rope_scaling" not in self.config: + self.config["rope_scaling"] = rope_parameters + + # MoE routing parameters - set defaults for Qwen3.5 compatibility + if "norm_topk_prob" not in self.config: + self.config["norm_topk_prob"] = True # Standard default for MoE models + + # Handle fine-tuning config if present + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + + # Calculate num_kv_heads for KV cache memory management + # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) diff --git a/lightllm/models/qwen3_5_moe/__init__.py b/lightllm/models/qwen3_5_moe/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..fe4b1883bd --- /dev/null +++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py @@ -0,0 +1,40 @@ +import torch +from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import Qwen35TransformerLayerWeight + + +class Qwen35MOETransformerLayerWeight(Qwen35TransformerLayerWeight): + def load_hf_weights(self, weights): + moe_intermediate_size = self.network_config_["moe_intermediate_size"] + split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size) + return super().load_hf_weights(weights) + + +def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int): + layer_prefix = f"model.layers.{layer_num}." + keys = list(weights.keys()) + num_experts = 0 + + for k in keys: + if not k.startswith(layer_prefix): + continue + + if "mlp.experts.gate_up_proj" in k: + fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size] + num_experts = fused_weight.shape[0] + + prefix = k.rsplit(".gate_up_proj", 1)[0] + gate_weight = fused_weight[:, :moe_intermediate_size, :] + up_weight = fused_weight[:, moe_intermediate_size:, :] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx] + weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx] + + elif "mlp.experts.down_proj" in k: + down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size] + num_experts = down_weight.shape[0] + + prefix = k.rsplit(".down_proj", 1)[0] + + for expert_idx in range(num_experts): + weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx] diff --git a/lightllm/models/qwen3_5_moe/model.py b/lightllm/models/qwen3_5_moe/model.py new file mode 100644 index 0000000000..973274774f --- /dev/null +++ b/lightllm/models/qwen3_5_moe/model.py @@ -0,0 +1,12 @@ +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel +from lightllm.utils.log_utils import init_logger +from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import ( + Qwen35MOETransformerLayerWeight, +) + + +@ModelRegistry("qwen3_5_moe", is_multimodal=True) +class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel): + + transformer_weight_class = Qwen35MOETransformerLayerWeight diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py index b85216f22c..721893a4cd 100644 --- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py @@ -22,14 +22,14 @@ class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, network_config): - self.n_routed_experts = network_config["num_experts"] + self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) - self.num_experts_per_tok = network_config["num_experts_per_tok"] - self.norm_topk_prob = network_config["norm_topk_prob"] + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 0) + self.norm_topk_prob = network_config.get("norm_topk_prob", True) super().__init__(layer_num, network_config) self.head_dim_ = network_config["head_dim"] self.tp_k_head_num_ = max(self.tp_k_head_num_, 1) diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py index 13ba6cbe0f..e525cb2d20 100644 --- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py @@ -5,11 +5,11 @@ class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight): def __init__(self, layer_num, data_type, network_config, quant_cfg=None): - self.n_routed_experts = network_config["num_experts"] + self.n_routed_experts = network_config.get("num_experts", 0) self.is_moe = ( - network_config["num_experts"] > 0 - and layer_num not in network_config["mlp_only_layers"] - and (layer_num + 1) % network_config["decoder_sparse_step"] == 0 + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 ) super().__init__(layer_num, data_type, network_config, quant_cfg) return diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py index 10a5051276..b71d7f4878 100644 --- a/lightllm/models/qwen3_moe/model.py +++ b/lightllm/models/qwen3_moe/model.py @@ -25,4 +25,6 @@ def __init__(self, kvargs): def _init_custom(self): super()._init_custom() - dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + # Only initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py index c20c227996..0276724749 100644 --- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py +++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py @@ -60,6 +60,8 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup) + self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py index d389c853d5..bed8898115 100644 --- a/lightllm/models/qwen3_vl/qwen3_visual.py +++ b/lightllm/models/qwen3_vl/qwen3_visual.py @@ -29,6 +29,9 @@ from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class Qwen3VLVisionMLP(nn.Module): @@ -60,6 +63,8 @@ def __init__( kernel_size = [temporal_patch_size, patch_size, patch_size] self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True) + # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup) + self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype @@ -377,6 +382,7 @@ def encode(self, images: List[ImageItem]): image_data = read_shm(get_shm_name_data(img.uuid)) image_data = Image.open(BytesIO(image_data)) pixel_values, image_grid_thw = self.processor.preprocess(image_data) + img_tensors.append(pixel_values) img_grids.append(image_grid_thw) else: diff --git a/lightllm/models/qwen3next/__init__.py b/lightllm/models/qwen3next/__init__.py new file mode 100644 index 0000000000..a9d22c6643 --- /dev/null +++ b/lightllm/models/qwen3next/__init__.py @@ -0,0 +1,3 @@ +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel + +__all__ = ["Qwen3NextTpPartModel"] diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py new file mode 100644 index 0000000000..cd7c8d908d --- /dev/null +++ b/lightllm/models/qwen3next/infer_struct.py @@ -0,0 +1,16 @@ +import torch +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3NextInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self.gate_value = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + self.b_att_seq_len = self.b_seq_len + self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous() + + return diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..ec07b38c5a --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -0,0 +1,381 @@ +import os +import torch + +import torch.distributed as dist +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextTransformerLayerWeight, +) +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl +from lightllm.utils.log_utils import init_logger +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd +from typing import Tuple +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule +from lightllm.distributed import all_reduce +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type +from functools import partial + +logger = init_logger(__name__) + + +class Qwen3NextTransformerLayerInfer(LlamaTransformerLayerInfer): + def __init__(self, layer_num, network_config): + self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) + self.n_routed_experts = network_config.get("num_experts", 0) + self.is_moe = ( + network_config.get("num_experts", 0) > 0 + and layer_num not in network_config.get("mlp_only_layers", []) + and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0 + ) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1) + self.norm_topk_prob = network_config.get("norm_topk_prob", False) + + super().__init__(layer_num, network_config) + self.head_dim_ = network_config.get( + "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"] + ) + num_full_attention_layers = network_config["full_attention_interval"] + self.is_linear_attention_layer = (layer_num + 1) % num_full_attention_layers != 0 + if self.is_linear_attention_layer: + self._init_linear_layer_metadata(layer_num, network_config) + return + + def _init_linear_layer_metadata(self, layer_num, network_config): + + # Linear attention specific dimensions + self.num_v_heads = network_config["linear_num_value_heads"] + self.num_k_heads = network_config["linear_num_key_heads"] + self.head_k_dim = network_config["linear_key_head_dim"] + self.head_v_dim = network_config["linear_value_head_dim"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = network_config["linear_conv_kernel_dim"] + self.activation = network_config["hidden_act"] + + # Tensor parallelism dimensions + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + # SSM state dtype optimization + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + start_args = get_env_start_args() + self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16) + + # Pre-compute whether dtype conversion is needed + # GDN kernel output dtype is self.data_type + # Conversion needed only if SSM state uses different dtype + self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype + return + + def _bind_func(self): + super()._bind_func() + self._bind_ffn() + return + + def _bind_ffn(self): + if self.is_moe: + moe_mode = os.environ.get("MOE_MODE", "TP") + if moe_mode == "EP": + self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn_edp, self) + else: + self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn, self) + else: + self._ffn = partial(Qwen3NextTransformerLayerInfer._ffn, self) + return + + def _compute_shared_expert( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + input = input.view(-1, self.embed_dim_) + shared_expert_out = super()._ffn(input, infer_state, layer_weight) + gate = layer_weight.ffn_gate.mm(input).sigmoid_() + shared_expert_out.mul_(gate) + return shared_expert_out + + def _moe_ffn( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + + hidden_states = input.view(-1, self.embed_dim_) + num_tokens, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + ) + hidden_states = hidden_states.view(num_tokens, hidden_dim) + hidden_states.add_(shared_expert_out) + return hidden_states + + def _moe_ffn_edp( + self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight) + hidden_states = input + token_num, hidden_dim = hidden_states.shape + router_logits = layer_weight.moe_gate.mm(hidden_states) + ep_output = layer_weight.experts.experts( + hidden_states, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + ) + ep_output = ep_output.view(token_num, hidden_dim) + ep_output.add_(shared_expert_out) + return ep_output + + def _get_qkv( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + qkv_out = layer_weight.qkv_proj.mm(input) + q, cache_kv = qkv_out.split( + [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], + dim=-1, + ) + o_gate = layer_weight._o_gate_proj.mm(input) + # In-place sigmoid saves one allocation (gate_value is consumed once in _get_o) + infer_state.gate_value = o_gate.sigmoid_() + layer_weight.qk_norm_weight_( + q, + cache_kv[:, : self.tp_k_head_num_ * self.head_dim_], + eps=self.eps_, + ) + cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv + + def _get_o( + self, + input, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ) -> torch.Tensor: + """Output projection with gating (in-place multiply to save one allocation).""" + input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) + input.mul_(infer_state.gate_value) + infer_state.gate_value = None + o_tensor = layer_weight.o_proj.mm(input) + return o_tensor + + # ==================== GDN Helper Methods ==================== + + def context_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + if not self.is_linear_attention_layer: + return super().context_attention_forward(input_embdings, infer_state, layer_weight) + + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return gdn_out + + def token_attention_forward( + self, + input_embdings, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + if not self.is_linear_attention_layer: + return super().token_attention_forward(input_embdings, infer_state, layer_weight) + gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False) + if self.tp_world_size_ > 1: + all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return gdn_out + + def gdn_forward( + self, + input: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + is_prefill: bool, + ): + assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager) + + input = input.view(-1, self.embed_dim_) + conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_) + + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill) + + if is_prefill: + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) + core_attn_out = self._gdn_prefill_kernel( + mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight + ) + else: + core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight) + + num_tokens = z.shape[0] + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.linear_norm.weight, + None, + self.eps_, + z, + out=norm_out, + ) + core_attn_out = norm_out.view(num_tokens, -1) + output = layer_weight.linear_out_proj.mm(core_attn_out) + return output + + def _split_qkvzba(self, mixed_qkvzba, is_decode=False): + qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim + z_end = qkv_dim + self.tp_value_dim + b_end = z_end + self.tp_num_v_heads + mixed_qkv = mixed_qkvzba[:, :qkv_dim] + z = mixed_qkvzba[:, qkv_dim:z_end].view(-1, self.tp_num_v_heads, self.head_v_dim) + b = mixed_qkvzba[:, z_end:b_end] + a = mixed_qkvzba[:, b_end:] + return mixed_qkv, z, b, a + + def _rearrange_mixed_qkv(self, mixed_qkv, decode=False): + if mixed_qkv is None: + return None, None, None + if decode: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + batch_size = mixed_qkv.shape[0] + query = query.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + key = key.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim) + value = value.view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + else: + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + seq_len = query.shape[0] + query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim) + value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim) + return query, key, value + + def _gdn_prefill_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + """Prefill kernel for GDN forward pass.""" + # Conv1D processing + mixed_qkv = mixed_qkv.transpose(0, 1) + out_tensor = causal_conv1d_fn( + mixed_qkv, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + query_start_loc=infer_state.b1_cu_q_seq_len, + cache_indices=infer_state.b_buffer_idx, + has_initial_state=infer_state.b_ready_cache_len > 0, + conv_states=conv_states, + activation=self.activation, + ) + mixed_qkv = out_tensor.transpose(0, 1) + + # Recurrent processing + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + initial_state = ssm_states[infer_state.b_buffer_idx] + # g and beta have shape (total_tokens, num_heads), need to unsqueeze to get (1, total_tokens, num_heads) + core_attn_out, last_recurrent_state = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g.unsqueeze(0), + beta=beta.unsqueeze(0), + initial_state=initial_state, + output_final_state=True, + cu_seqlens=infer_state.b1_cu_q_seq_len, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + if self.needs_ssm_dtype_conversion: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state.to(self.ssm_state_dtype, copy=False) + else: + ssm_states[infer_state.b_buffer_idx] = last_recurrent_state + return core_attn_out + + def _gdn_decode_kernel( + self, + mixed_qkv: torch.Tensor, + conv_states: torch.Tensor, + ssm_states: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + infer_state: Qwen3NextInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ): + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states, + layer_weight.linear_conv1d.mm_param.weight, + bias=layer_weight.linear_conv1d.bias, + activation=self.activation, + conv_state_indices=infer_state.b_buffer_idx, + ) + + # Recurrent processing with fused gating + # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally + query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True) + core_attn_out, _ = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + initial_state=ssm_states, + inplace_final_state=True, + ssm_state_indices=infer_state.b_buffer_idx, + use_qk_l2norm_in_kernel=True, + A_log=layer_weight.linear_A_log.weight, + dt_bias=layer_weight.linear_dt_bias.weight, + a_raw=a, + b_raw=b, + ) + return core_attn_out diff --git a/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..daaf146907 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,29 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, GEMMANormWeight + + +class Qwen3NextPreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.embed_tokens.weight", + data_type=self.data_type_, + ) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_ if tie_word_embeddings else None, + ) + self.final_norm_weight_ = GEMMANormWeight( + dim=hidden_size, + weight_name="model.norm.weight", + data_type=self.data_type_, + ) + return diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..31dae85ec8 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -0,0 +1,273 @@ +import torch +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + RMSNormWeight, + GEMMANormWeight, + TpParameterWeight, + QKVROWNMMWeight, + QKGEMMANormWeight, +) + + +class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, quant_cfg=None): + num_full_attention_layers = network_config["full_attention_interval"] + self.is_linear_attention_layer = (layer_num + 1) % num_full_attention_layers != 0 + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + self.qkv_proj = QKVROWNMMWeight( + in_dim=in_dim, + q_head_num=self.q_head_num_, + kv_head_num=self.k_head_num_, + head_dim=self.head_dim, + weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name], + data_type=self.data_type_, + bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name], + quant_method=self.get_quant_method("qkv_proj"), + ) + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + self._o_gate_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=[self._o_gate_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("o_gate_proj"), + ) + + def _init_weight(self): + if self.is_linear_attention_layer: + self._init_gdn_weight() + else: + self._init_qkv() + self._init_o() + + if self.is_moe: + self._init_moe() + else: + self._init_ffn() + self._init_norm() + + def _init_moe(self): + super()._init_moe() + self._init_gated_ffn() + return + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + self.att_norm_weight_ = GEMMANormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = GEMMANormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + if not self.is_linear_attention_layer: + self.qk_norm_weight_ = QKGEMMANormWeight( + dim=self.head_dim, + q_weight_name=self._q_norm_name, + k_weight_name=self._k_norm_name, + data_type=self.data_type_, + ) + + def _init_gated_ffn(self): + hidden_size = self.network_config_["hidden_size"] + if "shared_expert_intermediate_size" not in self.network_config_: + return + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + inter_size = self.network_config_["shared_expert_intermediate_size"] + self.gate_up_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[inter_size, inter_size], + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=inter_size, + out_dims=[hidden_size], + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("down_proj"), + ) + self.ffn_gate = ROWMMWeight( + in_dim=hidden_size, + out_dims=[1], + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_method=None, + tp_rank=0, + tp_world_size=1, + ) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.q_head_num_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj + + def _parse_config(self): + super()._parse_config() + self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] + self.linear_num_k_heads = self.network_config_["linear_num_key_heads"] + self.linear_k_head_dim = self.network_config_["linear_key_head_dim"] + self.linear_v_head_dim = self.network_config_["linear_value_head_dim"] + + def _init_gdn_weight(self): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + hidden_size = self.network_config_["hidden_size"] + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + conv1d_channels = qk_dim + qk_dim + v_dim # q + k + v concatenated + kernel_size = self.network_config_.get("linear_conv_kernel_dim", 4) + + # Conv1d weight: after _preprocess_weight, shape is [channels, kernel_size]. + self.linear_conv1d = ROWMMWeight( + in_dim=kernel_size, + out_dims=[conv1d_channels], + weight_names=f"{prefix}.conv1d.weight", + data_type=self.data_type_, + quant_method=None, + ) + + # in_proj_qkvz: q(qk_dim) + k(qk_dim) + v(v_dim) + z(v_dim) + # in_proj_ba: beta(num_v_heads) + alpha(num_v_heads) — per-head scalars + qkvz_dim = qk_dim + qk_dim + v_dim + v_dim + ba_dim = self.linear_num_v_heads + self.linear_num_v_heads + self.linear_in_proj = ROWMMWeight( + in_dim=hidden_size, + out_dims=[qkvz_dim, ba_dim], + weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"], + data_type=self.data_type_, + quant_method=self.get_quant_method("in_proj_weight"), + ) + + self.linear_out_proj = COLMMWeight( + in_dim=v_dim, + out_dims=[hidden_size], + weight_names=f"{prefix}.out_proj.weight", + data_type=self.data_type_, + quant_method=self.get_quant_method("out_proj_weight"), + ) + + split_n_embed = self.linear_num_v_heads // self.tp_world_size_ + self.linear_dt_bias = TpParameterWeight( + weight_name=f"{prefix}.dt_bias", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + self.linear_A_log = TpParameterWeight( + weight_name=f"{prefix}.A_log", + data_type=torch.float32, + split_n_embed=split_n_embed, + bias_name=None, + weight_shape=(self.linear_num_v_heads,), # Full shape before TP split + bias_shape=None, + ) + + # Norm is applied per-head across head_dim, not across all heads + linear_norm_dim = self.linear_v_head_dim + self.linear_norm = RMSNormWeight( + dim=linear_norm_dim, + weight_name=f"{prefix}.norm.weight", + data_type=self.data_type_, + ) + + def _preprocess_weight(self, weights): + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + if linear_conv1d_weight_name in weights: + # squeeze [channels, 1, kernel] -> [channels, kernel], then rearrange for TP + # Result shape: [channels, kernel_size] — matches causal_conv1d_fn's (dim, width) + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + self._rearrange_gdn_in_proj_weights(weights) + + def _rearrange_gdn_in_proj_weights(self, weights): + """Rearrange in_proj_qkvz and in_proj_ba weight rows from interleaved per-k-head layout + to TP-aware grouped layout so that after ROWMMWeight's row-slicing, each rank's + MM output is already [q_chunk, k_chunk, v_chunk, z_chunk, b_chunk, a_chunk]. + """ + num_k = self.linear_num_k_heads + k_dim = self.linear_k_head_dim + v_dim = self.linear_v_head_dim + num_v_per_k = self.linear_num_v_heads // num_k + tp = self.tp_world_size_ + + # Rearrange in_proj_qkvz + qkvz_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_qkvz.weight" + if qkvz_name in weights: + w = weights[qkvz_name] + hidden = w.shape[-1] + # Each k-head group: q(k_dim) + k(k_dim) + v(num_v_per_k * v_dim) + z(num_v_per_k * v_dim) rows + group_size = k_dim + k_dim + num_v_per_k * v_dim + num_v_per_k * v_dim + w = w.view(num_k, group_size, hidden) + v_block = num_v_per_k * v_dim + all_q = w[:, :k_dim, :].reshape(-1, hidden) # [total_q_dim, H] + all_k = w[:, k_dim : 2 * k_dim, :].reshape(-1, hidden) # [total_k_dim, H] + all_v = w[:, 2 * k_dim : 2 * k_dim + v_block, :].reshape(-1, hidden) # [total_v_dim, H] + all_z = w[:, 2 * k_dim + v_block :, :].reshape(-1, hidden) # [total_v_dim, H] + # Chunk each component by TP, interleave so row-slicing gives grouped layout per rank + q_chunks = all_q.chunk(tp, dim=0) + k_chunks = all_k.chunk(tp, dim=0) + v_chunks = all_v.chunk(tp, dim=0) + z_chunks = all_z.chunk(tp, dim=0) + weights[qkvz_name] = torch.cat( + [torch.cat([q_chunks[i], k_chunks[i], v_chunks[i], z_chunks[i]], dim=0) for i in range(tp)], + dim=0, + ) + + # Rearrange in_proj_ba + ba_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_ba.weight" + if ba_name in weights: + w = weights[ba_name] + hidden = w.shape[-1] + group_size = 2 * num_v_per_k + w = w.view(num_k, group_size, hidden) + all_b = w[:, :num_v_per_k, :].reshape(-1, hidden) # [total_num_v, H] + all_a = w[:, num_v_per_k:, :].reshape(-1, hidden) # [total_num_v, H] + b_chunks = all_b.chunk(tp, dim=0) + a_chunks = all_a.chunk(tp, dim=0) + weights[ba_name] = torch.cat( + [torch.cat([b_chunks[i], a_chunks[i]], dim=0) for i in range(tp)], + dim=0, + ) + + def _parse_linear_conv1d(self, weight): + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + q, k, v = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q.chunk(self.tp_world_size_, dim=0) + k_splits = k.chunk(self.tp_world_size_, dim=0) + v_splits = v.chunk(self.tp_world_size_, dim=0) + new_weight = torch.cat( + [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0 + ) + return new_weight + + def load_hf_weights(self, weights): + self._split_q_with_gate(weights) + if self.is_linear_attention_layer: + self._preprocess_weight(weights) + super().load_hf_weights(weights) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py new file mode 100644 index 0000000000..5c8d486edf --- /dev/null +++ b/lightllm/models/qwen3next/mem_manager.py @@ -0,0 +1,154 @@ +import torch +from typing import Tuple +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager +from lightllm.server.core.objs.start_args_type import StartArgs + +logger = init_logger(__name__) + + +class Qwen3NextHybridMemManager(MemoryManager): + @staticmethod + def calculate_mamba_cache_size( + start_args: StartArgs, + max_total_token_num: int, + mem_fraction: float, + config: dict, + head_linear_k_dim: int, + num_linear_k_heads: int, + head_linear_v_dim: int, + num_linear_v_heads: int, + tp_world_size: int, + data_type: torch.dtype, + ) -> int: + """Calculate mamba cache size based on available memory and mamba_cache_ratio.""" + from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory + import torch.distributed as dist + + use_ratio = max_total_token_num is None and start_args.mamba_cache_size is None + + world_size = dist.get_world_size() + total_memory = get_total_gpu_memory() + available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction) + + conv_kernel_size = config["linear_conv_kernel_dim"] + conv_dim = ( + head_linear_k_dim * num_linear_k_heads * 2 + head_linear_v_dim * num_linear_v_heads + ) // tp_world_size + + num_linear_layers = config["n_layer"] - (config["n_layer"] // config["full_attention_interval"]) + + conv_cell_size = num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(data_type) + + ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32 + ssm_cell_size = ( + num_linear_layers + * (num_linear_v_heads // tp_world_size) + * head_linear_k_dim + * head_linear_v_dim + * torch._utils._element_size(ssm_dtype) + ) + + total_cell_size = conv_cell_size + ssm_cell_size + + if use_ratio: + # mamba_cache_ratio = mamba_memory / total_cache_memory + mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5 + mamba_memory_gb = available_memory * mamba_cache_ratio + else: + mamba_memory_gb = available_memory + mamba_cache_ratio = None + + mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size) + + if mamba_cache_size < start_args.running_max_req_size * 2: + ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5 + raise ValueError( + f"Insufficient memory for mamba cache allocation!\n\n" + f"mamba_cache_size should be at least running_max_req_size * 2\n" + f"Calculated mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n" + f"Memory budget:\n" + f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" + f" Calculated buffers: {mamba_cache_size}\n" + f" Required buffers: {start_args.running_max_req_size}\n\n" + f"Solutions:\n" + f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n" + f" 2. Increase --mamba_cache_ratio from {ratio} to " + f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n" + f" 3. Increase --mem_fraction to leave more memory for caches\n" + ) + + logger.info( + f"Mamba cache allocation:\n" + f" Available memory: {mamba_memory_gb:.2f} GB\n" + f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n" + f" Calculated mamba_cache_size: {mamba_cache_size}" + ) + + return mamba_cache_size + + def __init__( + self, + full_attn_cache_size, + linear_attn_cache_size, + dtype, + num_kv_heads, + head_dim, + layer_num, + mtp_layer_num, + full_attention_interval: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + max_req_num: int, + always_copy=False, + mem_fraction=0.9, + ): + + self.full_attention_interval = full_attention_interval + assert layer_num % full_attention_interval == 0 + self.layer_num = layer_num + self.mtp_layer_num = mtp_layer_num + self.full_attn_layer_num = layer_num // full_attention_interval + self.linear_attn_layer_num = layer_num - self.full_attn_layer_num + + self.mamba_cache_mem_manager = MambaCacheManager( + linear_attn_cache_size, + self.linear_attn_layer_num, + conv_state_dtype, + conv_state_shape, + ssm_state_dtype, + ssm_state_shape, + ) + + super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) + + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ..., + # None, kv_cache, mtp_kv_cache, mtp_kv_cache] + # Only full attention layers and MTP layers have KV cache. + self.kv_buffer = [None for _ in range(self.layer_num)] + for layer_id in range(self.full_attn_layer_num): + self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( + (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" + ) + for _ in range(self.mtp_layer_num): + self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) + + def free_all(self): + super().free_all() + self.mamba_cache_mem_manager.free_all() + return + + def get_cell_size(self): + # Only full attention layers and MTP layers have KV cache + kv_cache_layer_num = self.full_attn_layer_num + self.mtp_layer_num + return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype) + + def get_mamba_cache(self, layer_idx: int): + layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval) + return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py new file mode 100644 index 0000000000..b00f57f3ec --- /dev/null +++ b/lightllm/models/qwen3next/model.py @@ -0,0 +1,139 @@ +import torch +from typing import Optional +import triton +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import ( + Qwen3NextTransformerLayerWeight, +) +from lightllm.models.qwen3next.layer_weights.pre_and_post_layer_weight import Qwen3NextPreAndPostLayerWeight +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import ( + Qwen3NextTransformerLayerInfer, +) +from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.common.req_manager import ReqManagerForMamba +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_next") +class Qwen3NextTpPartModel(Qwen3MOEModel): + + # weight class + pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight + transformer_weight_class = Qwen3NextTransformerLayerWeight + + # infer class + transformer_layer_infer_class = Qwen3NextTransformerLayerInfer + + # infer state class + infer_state_class = Qwen3NextInferStateInfo + + def get_radix_class(self): + return HybridRadixCache + + def __init__(self, kvargs) -> None: + self.mem_manager: Qwen3NextHybridMemManager = None + + def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor: + return torch.empty(size, device="cuda", dtype=torch.int8) + + # Set Triton allocator for TMA descriptors + # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py + triton.set_allocator(_triton_allocator) + logger.info("Triton allocator set for Qwen3Next model") + super().__init__(kvargs) + + def autotune_layers(self): + return self.config["full_attention_interval"] + + def _init_config(self): + super()._init_config() + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + def _init_custom(self): + super()._init_custom() + # Only initialize DeepEP group for MoE models with num_experts + if "num_experts" in self.config and self.config["num_experts"] > 0: + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + + def _init_mem_manager(self): + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + + start_args: StartArgs = get_env_start_args() + mamba_cache_size = start_args.mamba_cache_size + + self.num_linear_k_heads = self.config["linear_num_key_heads"] + self.num_linear_v_heads = self.config["linear_num_value_heads"] + self.head_linear_k_dim = self.config["linear_key_head_dim"] + self.head_linear_v_dim = self.config["linear_value_head_dim"] + + if mamba_cache_size is None: + mamba_cache_size = Qwen3NextHybridMemManager.calculate_mamba_cache_size( + start_args=start_args, + max_total_token_num=self.max_total_token_num, + mem_fraction=self.mem_fraction, + config=self.config, + head_linear_k_dim=self.head_linear_k_dim, + num_linear_k_heads=self.num_linear_k_heads, + head_linear_v_dim=self.head_linear_v_dim, + num_linear_v_heads=self.num_linear_v_heads, + tp_world_size=self.tp_world_size_, + data_type=self.data_type, + ) + else: + if mamba_cache_size < start_args.running_max_req_size * 2: + raise ValueError( + f"Explicitly set mamba_cache_size ({mamba_cache_size}) < " + f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n" + f"Please increase mamba_cache_size to at least {start_args.running_max_req_size * 2}" + ) + + conv_kernel_size = self.config["linear_conv_kernel_dim"] + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) + + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + if start_args.mamba_ssm_data_type not in ssm_dtype_dict: + raise ValueError( + f"Invalid mamba_ssm_data_type: {start_args.mamba_ssm_data_type}." + f" Must be one of {list(ssm_dtype_dict.keys())}" + ) + + self.mem_manager = Qwen3NextHybridMemManager( + full_attn_cache_size=self.max_total_token_num, + linear_attn_cache_size=mamba_cache_size, + dtype=self.data_type, + num_kv_heads=self.num_kv_heads, + head_dim=self.config["head_dim"], + layer_num=self.config["n_layer"], + mtp_layer_num=start_args.mtp_step, + full_attention_interval=self.config["full_attention_interval"], + conv_state_dtype=self.data_type, + conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1), + ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], + ssm_state_shape=( + self.num_linear_v_heads // self.tp_world_size_, + self.head_linear_k_dim, + self.head_linear_v_dim, + ), + max_req_num=self.max_req_num, + mem_fraction=self.mem_fraction, + ) + + def _init_req_manager(self): + create_max_seq_len = 0 + + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager) diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py new file mode 100644 index 0000000000..c6d099a2d8 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -0,0 +1,122 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d.py + +from typing import Optional + +import torch + +from sgl_kernel import causal_conv1d_fwd +from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = -1, + **kwargs, +): + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen + sequences are concatenated from left to right for varlen + weight: (dim, width) + bias: (dim,) + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + + causal_conv1d_fwd( + x, + weight, + bias, + conv_states, + query_start_loc, + cache_indices, + has_initial_state, + activation in ["silu", "swish"], + pad_slot_id, + ) + return x + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Optional[str] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + pad_slot_id: int = -1, +): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError(f"activation must be None, silu, or swish, actual: {activation}") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + causal_conv1d_update_kernel( + x, + conv_state, + weight, + bias, + activation_val, + cache_seqlens, + conv_state_indices, + pad_slot_id, + ) + if unsqueeze: + x = x.squeeze(-1) + return x diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py new file mode 100644 index 0000000000..2bde70bb99 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -0,0 +1,11 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Adapted from +# https://github.com/vllm-project/vllm diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py new file mode 100644 index 0000000000..cd3b0962a3 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py new file mode 100644 index 0000000000..7b3067bbfb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -0,0 +1,224 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_size=64, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: torch.LongTensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: torch.LongTensor | None = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, + ) + if head_first: + o = rearrange(o, "b t h ... -> b h t ...") + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py new file mode 100644 index 0000000000..97933b2ac2 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -0,0 +1,324 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp, safe_exp +from lightllm.common.triton_utils.autotuner import autotune + +NUM_WARPS = [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v = b_v * safe_exp(b_g_last - b_g)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), + other=0.0, + ) + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), + other=0.0, + ) + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), + other=0.0, + ) + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + mask=(o_k4 < K), + other=0.0, + ) + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_delta_h_configs(): + return [ + {"BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ] + + +def _get_chunk_delta_h_static_key(k, u, chunk_size): + B, T, Hg, K = k.shape + V = u.shape[-1] + H = u.shape[-2] + return {"H": H, "K": K, "V": V, "BT": chunk_size} + + +def _get_chunk_delta_h_run_key(k, u): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_gated_delta_rule_fwd_h", + configs_gen_func=_get_chunk_delta_h_configs, + static_key_func=_get_chunk_delta_h_static_key, + run_key_func=_get_chunk_delta_h_run_key, +) +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: torch.LongTensor | None = None, + run_config=None, +) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + # Extract config parameters + if run_config is None: + run_config = {"BV": 64, "num_warps": 2, "num_stages": 2} + + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py new file mode 100644 index 0000000000..fc49763ecd --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -0,0 +1,205 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp, safe_exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper +from lightllm.common.triton_utils.autotuner import autotune + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_o_configs(): + return [ + {"BK": BK, "BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_o_static_key(q, v, chunk_size): + B, T, Hg, K = q.shape + V = v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + return {"H": H, "K": K, "V": V, "BT": BT} + + +def _get_chunk_o_run_key(q, v): + # Return batch * heads as run key + return q.shape[0] * q.shape[2] + + +@autotune( + kernel_name="chunk_fwd_o", + configs_gen_func=_get_chunk_o_configs, + static_key_func=_get_chunk_o_static_key, + run_key_func=_get_chunk_o_run_key, +) +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor | None = None, # cumsum of log decay + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + run_config=None, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "BV": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 0000000000..60a594c078 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,180 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import safe_exp +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_A += tl.dot(b_k, tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * safe_exp(b_g_diff) + + b_A *= b_beta[:, None] + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_chunk_scaled_dot_kkt_configs(): + return [ + {"BK": BK, "num_warps": num_warps, "num_stages": num_stages} + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size=64, cu_seqlens=None): + B, T, Hg, K = k.shape + H = beta.shape[-1] + IS_VARLEN = cu_seqlens is not None + return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN} + + +def _get_chunk_scaled_dot_kkt_run_key(k, beta): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] + + +@autotune( + kernel_name="chunk_scaled_dot_kkt_fwd", + configs_gen_func=_get_chunk_scaled_dot_kkt_configs, + static_key_func=_get_chunk_scaled_dot_kkt_static_key, + run_key_func=_get_chunk_scaled_dot_kkt_run_key, +) +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, + run_config=None, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. + B, T, Hg, K = k.shape + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + g=g, + beta=beta, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + BK=BK, + num_warps=num_warps, + num_stages=num_stages, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py new file mode 100644 index 0000000000..6331e1602d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -0,0 +1,306 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard +from lightllm.common.triton_utils.autotuner import autotune + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + else: + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def _get_cumsum_scalar_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8]] + + +def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_scalar_run_key(g): + # Return total number of elements as run key + return g.shape[0] * g.shape[1] + + +@autotune( + kernel_name="chunk_local_cumsum_scalar", + configs_gen_func=_get_cumsum_scalar_configs, + static_key_func=_get_cumsum_scalar_static_key, + run_key_func=_get_cumsum_scalar_run_key, +) +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"num_warps": 2} + + num_warps = run_config.get("num_warps", 2) + + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +def _get_cumsum_vector_configs(): + return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] + + +def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first): + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} + + +def _get_cumsum_vector_run_key(g): + # Return batch * heads as run key + return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0] + + +@autotune( + kernel_name="chunk_local_cumsum_vector", + configs_gen_func=_get_cumsum_vector_configs, + static_key_func=_get_cumsum_vector_static_key, + run_key_func=_get_cumsum_vector_run_key, +) +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + run_config=None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"BS": 32, "num_warps": 2} + + BS = run_config.get("BS", 32) + num_warps = run_config.get("num_warps", 2) + + grid = (triton.cdiv(S, BS), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + num_warps=num_warps, + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: torch.Tensor | None = None, + head_first: bool = False, + output_dtype: torch.dtype | None = torch.float, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py new file mode 100644 index 0000000000..22a93a2c99 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -0,0 +1,492 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + "HAS_SEPARATE_WRITE_INDICES": lambda args: args["ssm_state_write_indices"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + ssm_state_write_indices, # NEW: separate write indices for state propagation optimization + num_accepted_tokens, + # Fused gating parameters (only used when FUSE_GATING=True) + A_log, # [HV] per-head log decay + dt_bias, # [HV] per-head dt bias + a_raw, # [B*T, HV] raw alpha values (before softplus) + b_raw, # [B*T, HV] raw beta values (before sigmoid) + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + stride_write_indices_seq: tl.constexpr, # NEW: stride for write indices + stride_write_indices_tok: tl.constexpr, # NEW: stride for write indices + SOFTPLUS_BETA: tl.constexpr, # softplus beta parameter (default 1.0) + SOFTPLUS_THRESHOLD: tl.constexpr, # softplus threshold (default 20.0) + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, + HAS_SEPARATE_WRITE_INDICES: tl.constexpr, # NEW: whether to use separate write indices + FUSE_GATING: tl.constexpr, # whether to compute g/beta inline from raw values +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if FUSE_GATING: + # Fused gating: load per-head constants once, compute g/beta inline per token + b_A_log = tl.load(A_log + i_hv).to(tl.float32) + b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32) + p_a_raw = a_raw + bos * HV + i_hv + p_b_raw = b_raw + bos * HV + i_hv + else: + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + if FUSE_GATING: + # Compute g = -exp(A_log) * softplus(a_raw + dt_bias) inline + b_a = tl.load(p_a_raw).to(tl.float32) + x = b_a + b_dt_bias + softplus_x = tl.where( + SOFTPLUS_BETA * x <= SOFTPLUS_THRESHOLD, + (1.0 / SOFTPLUS_BETA) * tl.log(1.0 + tl.exp(SOFTPLUS_BETA * x)), + x, + ) + b_g = -tl.exp(b_A_log) * softplus_x + b_h *= exp(b_g) + # Compute beta = sigmoid(b_raw) inline + b_b = tl.load(p_b_raw).to(tl.float32) + b_beta = tl.sigmoid(b_b) + else: + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + # Use separate write indices if provided (for state propagation optimization) + # Otherwise fall back to read indices + if HAS_SEPARATE_WRITE_INDICES: + write_idx = tl.load(ssm_state_write_indices + i_n * stride_write_indices_seq + i_t).to(tl.int64) + else: + write_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) + p_ht = ht + write_idx * stride_final_state_token + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + if FUSE_GATING: + p_a_raw += HV + p_b_raw += HV + else: + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + # Fused gating parameters + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK = triton.next_power_of_2(K) + if T == 1: + # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16) + # and more warps for better SM utilization at T=1 where there's no pipelining benefit + BV = min(triton.next_power_of_2(V), 32) + num_warps = 4 + num_stages = 1 + else: + # Prefill path: small BV for better pipelining across sequence length + BV = min(triton.next_power_of_2(V), 8) + num_warps = 1 + num_stages = 3 + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + + fuse_gating = A_log is not None + + if out is not None: + o = out.unsqueeze(0) if out.ndim == v.ndim else out + else: + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + # Strides for read indices + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + # Strides for write indices (if provided) + if ssm_state_write_indices is None: + stride_write_indices_seq, stride_write_indices_tok = 1, 1 + elif ssm_state_write_indices.ndim == 1: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1 + else: + stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw, + b_raw=b_raw, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + stride_write_indices_seq=stride_write_indices_seq, + stride_write_indices_tok=stride_write_indices_tok, + SOFTPLUS_BETA=1.0, + SOFTPLUS_THRESHOLD=20.0, + IS_BETA_HEADWISE=False if fuse_gating else (beta.ndim == v.ndim), + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, + FUSE_GATING=fuse_gating, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous() if g is not None else None, + beta=beta.contiguous() if beta is not None else None, + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + ssm_state_write_indices=ssm_state_write_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + A_log=A_log, + dt_bias=dt_bias, + a_raw=a_raw.contiguous() if a_raw is not None else None, + b_raw=b_raw.contiguous() if b_raw is not None else None, + out=out, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor = None, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices for state propagation + num_accepted_tokens: torch.Tensor | None = None, + use_qk_l2norm_in_kernel: bool = False, + # Fused gating: pass raw values to compute g/beta inline in the kernel + A_log: torch.Tensor | None = None, + dt_bias: torch.Tensor | None = None, + a_raw: torch.Tensor | None = None, + b_raw: torch.Tensor | None = None, + out: torch.Tensor | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + fuse_gating = A_log is not None + if not fuse_gating and beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + ssm_state_write_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + A_log, + dt_bias, + a_raw, + b_raw, + out, + ) + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py new file mode 100644 index 0000000000..8b1d59fc63 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py new file mode 100644 index 0000000000..29f892ef26 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import torch + +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def _get_l2norm_kernel1_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] + + +def _get_l2norm_kernel1_static_key(x): + D = x.shape[-1] + return {"D": D} + + +def _get_l2norm_kernel1_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel1", + configs_gen_func=_get_l2norm_kernel1_configs, + static_key_func=_get_l2norm_kernel1_static_key, + run_key_func=_get_l2norm_kernel1_run_key, +) +def _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD, run_config=None): + if run_config is None: + run_config = {"num_warps": 4} + + num_warps = run_config.get("num_warps", 4) + T = x.shape[0] + + l2norm_fwd_kernel1[(T,)](x, y, eps=eps, D=D, BD=BD, num_warps=num_warps) + + +def _get_l2norm_kernel_configs(): + return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] + + +def _get_l2norm_kernel_static_key(x): + D = x.shape[-1] + return {"D": D} + + +def _get_l2norm_kernel_run_key(x): + return x.shape[0] # T + + +@autotune( + kernel_name="l2norm_fwd_kernel", + configs_gen_func=_get_l2norm_kernel_configs, + static_key_func=_get_l2norm_kernel_static_key, + run_key_func=_get_l2norm_kernel_run_key, +) +def _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB, run_config=None): + if run_config is None: + run_config = {"BT": 32, "num_warps": 4} + + BT = run_config.get("BT", 32) + num_warps = run_config.get("num_warps", 4) + + grid = (triton.cdiv(T, BT),) + l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BT=BT, BD=BD, num_warps=num_warps) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB) + else: + _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py new file mode 100644 index 0000000000..2f69aa981d --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -0,0 +1,65 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import triton +import triton.language as tl + +from .utils import is_gather_supported + +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +@triton.jit +def safe_exp(x): + """ + Numerically stable exponential function. + Only applies exp to non-positive values, returns 0 for positive values. + This prevents numerical overflow and improves stability. + """ + return exp(tl.where(x <= 0, x, float("-inf"))) + + +if not is_gather_supported: + + @triton.jit + def gather(src, index, axis, _builder=None): + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None + +else: + gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py new file mode 100644 index 0000000000..b5b6cfc369 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -0,0 +1,462 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import make_tensor_descriptor +from .utils import input_guard, is_amd, is_tma_supported + + +def _ensure_triton_allocator(): + """Ensure Triton has an allocator set for kernels requiring scratch memory.""" + + def alloc_fn(size: int, alignment: int, stream: Optional[int]): + return torch.empty(size, device="cuda", dtype=torch.int8) + + triton.set_allocator(alloc_fn) + + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") +ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] +assert ( + FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS +), f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}" + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + A = A + (bos * H + i_h) * BT + Ai = Ai + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + # [16, 16] + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = -tl.where(m_A, b_A, 0) + + for i in range(2, min(16, T - i_t * 16)): + # [16] + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32) + b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + b_Ai_32 = -tl.dot( + tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), + b_Ai_22, + input_precision=DOT_PRECISION, + ) + b_Ai_43 = -tl.dot( + tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), + b_Ai_33, + input_precision=DOT_PRECISION, + ) + + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")) + + +@input_guard +def solve_tril( + A: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + """ + Compute the inverse of the matrix I + A + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, BT], where BT should only be 16, 32, or 64. + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. Default: `None`. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + output_dtype = A.dtype if output_dtype is None else output_dtype + + B, T, H, BT = A.shape + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + + Ai = torch.zeros_like(A, dtype=output_dtype) + if BT == 16: + merge_fn = solve_tril_16x16_kernel + elif BT == 32: + merge_fn = merge_16x16_to_32x32_inverse_kernel + elif BT == 64: + merge_fn = merge_16x16_to_64x64_inverse_kernel + + # Ensure Triton allocator is set for TMA kernels that require scratch memory + if is_tma_supported: + _ensure_triton_allocator() + + merge_fn[NT, B * H]( + A=A, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + USE_TMA=is_tma_supported, + DOT_PRECISION=FLA_TRIL_PRECISION, + ) + return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py new file mode 100644 index 0000000000..cd7c2e3aeb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py @@ -0,0 +1,179 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from collections.abc import Callable +from enum import Enum +from typing import Any, Literal + +import torch + +import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache_entries: tuple[tuple | None, dict | None, Any] = [] + cache_size = 8 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. +device = "cuda" +device_torch_lib = getattr(torch, device, None) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = True +is_gather_supported = hasattr(triton.language, "gather") +is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") +) + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py new file mode 100644 index 0000000000..08bb00e644 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: torch.LongTensor | None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py new file mode 100644 index 0000000000..6413158a66 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py @@ -0,0 +1,186 @@ +import torch + +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _fused_add_gemma_rmsnorm_kernel( + x_ptr, + r_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + r_stride0, + r_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Fused in-place residual add + Gemma RMSNorm. + + For each row: + 1. sum = x + residual (written back to x in-place) + 2. rstd = 1 / sqrt(mean(sum²) + eps) + 3. y = sum * rstd * (w + 1.0) (Gemma-style) + """ + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + r_ptr = r_ptr + row * r_stride0 + y_ptr = y_ptr + row * y_stride0 + + # Pass 1: compute sum = x + residual, write back to x, accumulate sum² for variance + _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + r = tl.load(r_ptr + cols * r_stride1, mask=mask, other=0.0).to(tl.float32) + s = x + r + # Write sum back to x (in-place residual add) + tl.store(x_ptr + cols * x_stride1, s.to(x_ptr.dtype.element_ty), mask=mask) + _var += s * s + + var = tl.sum(_var, axis=0) / N + rstd = 1.0 / tl.sqrt(var + EPS) + + # Pass 2: normalize and apply Gemma-style linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + # Re-read x (now contains sum); hot in L2 from the write in pass 1 + s = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + y = s * rstd * (w + 1.0) + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_fused_add_gemma_rmsnorm_configs(): + """Generate configurations for autotuning fused add + Gemma RMSNorm kernel.""" + configs = [] + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + for num_warps in [1, 2, 4, 8]: + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1}) + return configs + + +def _get_fused_add_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="fused_add_gemma_rmsnorm:v1", + configs_gen_func=_get_fused_add_gemma_rmsnorm_configs, + static_key_func=_get_fused_add_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], + mutates_args=["x"], +) +def fused_add_gemma_rmsnorm(x, residual, w, eps, out=None, run_config: dict = None): + """Fused in-place residual add + Gemma RMSNorm. + + x: [M, N] - modified in-place (x += residual) + residual: [M, N] - residual to add (will be viewed as [-1, N]) + w: [N] - norm weight (Gemma-style: applies w + 1.0) + eps: float + out: [M, N] - output buffer (allocated if None) + Returns: out + """ + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + r_arg = residual.view(-1, N) + y_arg = y.view(-1, N) + + M = x_arg.shape[0] + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This fused_add_gemma_rmsnorm doesn't support feature dim >= 64KB.") + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _fused_add_gemma_rmsnorm_kernel[(M,)]( + x_arg, + r_arg, + w, + y_arg, + x_stride0=x_arg.stride(0), + x_stride1=x_arg.stride(1), + r_stride0=r_arg.stride(0), + r_stride1=r_arg.stride(1), + y_stride0=y_arg.stride(0), + y_stride1=y_arg.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _fused_add_gemma_rmsnorm_torch(x, residual, weight, eps): + """Reference implementation for correctness testing.""" + original_dtype = x.dtype + x = x.to(torch.float32) + residual = residual.to(torch.float32) + s = x + residual + normed = s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps) + out = normed * (1.0 + weight.float()) + return s.to(original_dtype), out.to(original_dtype) + + +def test_fused_add_gemma_rmsnorm(M=128, N=2048, dtype=torch.bfloat16, eps=1e-5, device="cuda"): + """Verify fused kernel matches separate add + gemma_rmsnorm.""" + x_shape = (M, N) + w_shape = (N,) + weight = torch.rand(w_shape, dtype=dtype, device=device) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device) + residual = 0.1 * torch.randn(x_shape, dtype=dtype, device=device) + + # Clone x for reference (since fused modifies x in-place) + x_ref = x.clone() + x_fused = x.clone() + + # Reference: separate add + norm + x_ref_sum, y_ref = _fused_add_gemma_rmsnorm_torch(x_ref, residual, weight, eps) + + # Fused kernel + y_fused = fused_add_gemma_rmsnorm(x_fused, residual, weight, eps) + + # Check x was modified in-place (x += residual) + print(f"Test: M={M}, N={N}, dtype={dtype}") + print(f" x in-place max delta: {torch.max(torch.abs(x_fused - x_ref_sum)):.6e}") + print(f" output max delta: {torch.max(torch.abs(y_fused - y_ref)):.6e}") + + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(x_fused, x_ref_sum, atol=atol, rtol=0), "x in-place update mismatch!" + assert torch.allclose(y_fused, y_ref, atol=atol, rtol=0), "output mismatch!" + print(" PASSED") + + +if __name__ == "__main__": + test_fused_add_gemma_rmsnorm(M=1, N=2048) + test_fused_add_gemma_rmsnorm(M=128, N=2048) + test_fused_add_gemma_rmsnorm(M=1, N=2048, dtype=torch.float16) + test_fused_add_gemma_rmsnorm(M=64, N=4096, dtype=torch.float32) + print("All tests passed!") diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py new file mode 100644 index 0000000000..88febaffc6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -0,0 +1,93 @@ +# Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +from typing import Optional, Tuple + +import torch +import triton +import triton.language as tl + +from lightllm.common.triton_utils.autotuner import autotune + + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +# beta_output = b.sigmoid() +@triton.jit +def fused_gdn_gating_kernel( + g, + beta_output, + A_log, + a, + b, + dt_bias, + stride_a_row, + stride_b_row, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_d = tl.program_id(0), tl.program_id(1) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * NUM_HEADS + head_off + off_a = i_b * stride_a_row + head_off + off_b = i_b * stride_b_row + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off_a, mask=mask) + blk_b = tl.load(b + off_b, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask) + + +def _get_fused_gdn_gating_configs(): + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [4, 8, 16, 32, 64] for nw in [1, 2, 4]] + + +def _get_fused_gdn_gating_static_key(a: torch.Tensor): + # group by head size and input dtype + return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)} + + +@autotune( + kernel_name="fused_gdn_gating:v1", + configs_gen_func=_get_fused_gdn_gating_configs, + static_key_func=_get_fused_gdn_gating_static_key, + run_key_func=lambda a: a.shape[0], +) +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + + if run_config is None: + run_config = {"BLK_HEADS": 8, "num_warps": 1} + + batch, num_heads = a.shape + grid = (batch, triton.cdiv(num_heads, run_config["BLK_HEADS"])) + g = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device) + fused_gdn_gating_kernel[grid]( + g, + beta_output, + A_log, + a, + b, + dt_bias, + a.stride(0), + b.stride(0), + num_heads, + beta, + threshold, + run_config["BLK_HEADS"], + num_warps=run_config["num_warps"], + ) + return g, beta_output diff --git a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py new file mode 100644 index 0000000000..f37d4911af --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py @@ -0,0 +1,163 @@ +""" +Fused QKV projection and GDN gating computation. + +This kernel fuses: +1. Linear projection (matmul with weight) +2. Output reorganization (split and reshape) +3. Gating computation (g and beta from a, b) + +This reduces kernel launches from 3 to 1 for the QKV+gating path. +""" + +import torch +import triton +import triton.language as tl +from typing import Tuple, Optional +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _fused_gdn_gating_only_kernel( + # Output pointers + g_ptr, + beta_ptr, + # Input pointers + a_ptr, + b_ptr, + A_log_ptr, + dt_bias_ptr, + # Dimensions + batch_size, + num_heads, + # Constants + beta_const: tl.constexpr, + threshold: tl.constexpr, + BLOCK_BATCH: tl.constexpr, + BLOCK_HEADS: tl.constexpr, +): + """ + Fused kernel for GDN gating computation with better memory access patterns. + + Computes: + - g = -exp(A_log) * softplus(a + dt_bias) + - beta = sigmoid(b) + """ + pid_batch = tl.program_id(0) + pid_head = tl.program_id(1) + + batch_offs = pid_batch * BLOCK_BATCH + tl.arange(0, BLOCK_BATCH) + head_offs = pid_head * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS) + + batch_mask = batch_offs < batch_size + head_mask = head_offs < num_heads + mask = batch_mask[:, None] & head_mask[None, :] + + # Load A_log and dt_bias (broadcast across batch) + A_log = tl.load(A_log_ptr + head_offs, mask=head_mask, other=0.0) + dt_bias = tl.load(dt_bias_ptr + head_offs, mask=head_mask, other=0.0) + + # Load a and b + offs = batch_offs[:, None] * num_heads + head_offs[None, :] + a = tl.load(a_ptr + offs, mask=mask, other=0.0) + b = tl.load(b_ptr + offs, mask=mask, other=0.0) + + # Compute g = -exp(A_log) * softplus(a + dt_bias) + x = a.to(tl.float32) + dt_bias.to(tl.float32) + softplus_x = tl.where(beta_const * x <= threshold, (1.0 / beta_const) * tl.log(1.0 + tl.exp(beta_const * x)), x) + g = -tl.exp(A_log.to(tl.float32)) * softplus_x + + # Compute beta = sigmoid(b) + beta_out = tl.sigmoid(b.to(tl.float32)) + + # Store outputs with layout [1, batch, num_heads] + out_offs = batch_offs[:, None] * num_heads + head_offs[None, :] + tl.store(g_ptr + out_offs, g.to(g_ptr.dtype.element_ty), mask=mask) + tl.store(beta_ptr + out_offs, beta_out.to(beta_ptr.dtype.element_ty), mask=mask) + + +def _get_fused_gating_configs(): + """Generate autotuning configurations.""" + configs = [] + for block_batch in [1, 4, 8, 16]: + for block_heads in [8, 16, 32]: + for num_warps in [2, 4, 8]: + configs.append( + { + "BLOCK_BATCH": block_batch, + "BLOCK_HEADS": block_heads, + "num_warps": num_warps, + } + ) + return configs + + +def _get_fused_gating_static_key(a: torch.Tensor): + return {"dtype": str(a.dtype), "num_heads": a.shape[1]} + + +def _get_fused_gating_run_key(a: torch.Tensor): + return a.shape[0] + + +@autotune( + kernel_name="fused_gdn_gating_v2:v1", + configs_gen_func=_get_fused_gating_configs, + static_key_func=_get_fused_gating_static_key, + run_key_func=_get_fused_gating_run_key, + mutates_args=["g", "beta"], +) +def fused_gdn_gating_v2( + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + beta_const: float = 1.0, + threshold: float = 20.0, + run_config: Optional[dict] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Optimized GDN gating with pre-allocated output tensors. + + Args: + a: Input tensor [batch, num_heads] + b: Input tensor [batch, num_heads] + A_log: Log of A parameter [num_heads] + dt_bias: Bias for dt [num_heads] + g: Output tensor [1, batch, num_heads] (pre-allocated) + beta: Output tensor [1, batch, num_heads] (pre-allocated) + beta_const: Beta constant for softplus (default: 1.0) + threshold: Threshold for softplus approximation (default: 20.0) + run_config: Optional autotuning configuration + + Returns: + Tuple of (g, beta) - same tensors passed in, now filled + """ + batch_size, num_heads = a.shape + + if run_config is None: + run_config = {"BLOCK_BATCH": 8, "BLOCK_HEADS": 16, "num_warps": 4} + + grid = ( + triton.cdiv(batch_size, run_config["BLOCK_BATCH"]), + triton.cdiv(num_heads, run_config["BLOCK_HEADS"]), + ) + + _fused_gdn_gating_only_kernel[grid]( + g, + beta, + a, + b, + A_log, + dt_bias, + batch_size, + num_heads, + beta_const, + threshold, + run_config["BLOCK_BATCH"], + run_config["BLOCK_HEADS"], + num_warps=run_config["num_warps"], + ) + + return g, beta diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py new file mode 100644 index 0000000000..89db5e00cb --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py @@ -0,0 +1,174 @@ +import triton +import triton.language as tl +import torch +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + } +) +@triton.jit +def gated_rmsnorm_forward_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch (required, not optional) + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + Z += row * stride_z_row + group * N + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute variance (RMS norm doesn't use mean) + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + # RMS norm: compute variance directly without mean subtraction + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + # RMS norm: normalize without mean subtraction + x_hat = x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _get_gated_rmsnorm_configs(): + """Generate configurations for autotuning gated RMSNorm kernel.""" + configs = [] + # Different BLOCK_N sizes (powers of 2) + for block_n in [64, 128, 256, 512, 1024, 2048, 4096]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + # Skip configurations that are likely to be inefficient + if block_n >= 2048 and num_warps > 4: + continue + if block_n <= 128 and num_warps > 2: + continue + configs.append({"BLOCK_N": block_n, "num_warps": num_warps}) + return configs + + +def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + M, N = x.shape + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(weight.dtype), + "N": N, + "has_bias": bias is not None, + } + + +@autotune( + kernel_name="gated_rmsnorm_forward:v1", + configs_gen_func=_get_gated_rmsnorm_configs, + static_key_func=_get_gated_rmsnorm_static_key, + run_key_func=lambda x: x.shape[0], +) +def gated_rmsnorm_forward( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + run_config: dict = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + # z is required for gated_rmsnorm + assert z is not None, "z cannot be None for gated_rmsnorm_forward" + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + # For RMS norm, we still need rstd for the kernel + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + + # Validate BLOCK_N against group_size + if group_size > BLOCK_N: + # Fall back to largest valid BLOCK_N + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M, ngroups) + gated_rmsnorm_forward_kernel[grid]( + x, + out, + weight, + bias, + z, + rstd, + x.stride(0), + out.stride(0), + z.stride(0), + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + num_warps=num_warps, + ) + return out diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py index 8428e52996..b43f8f95af 100644 --- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py +++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py @@ -195,7 +195,8 @@ def flash_attention_v3_fwd( False, window_size[0], window_size[1], - 0.0, + 0, # attention_chunk + 0.0, # softcap is_rotary_interleaved=False, scheduler_metadata=None, num_splits=1, diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index c8a82d3239..8ff03f3e29 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -128,7 +128,18 @@ def make_argument_parser() -> argparse.ArgumentParser: parser.add_argument( "--tool_call_parser", type=str, - choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "deepseekv32", "glm47", "kimi_k2"], + choices=[ + "qwen25", + "llama3", + "mistral", + "deepseekv3", + "qwen", + "deepseekv31", + "deepseekv32", + "glm47", + "kimi_k2", + "qwen3_coder", + ], default=None, help="tool call parser type", ) @@ -167,7 +178,7 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument( - "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time" + "--running_max_req_size", type=int, default=256, help="the max size for forward requests in the same time" ) parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes") parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node") @@ -568,7 +579,13 @@ def make_argument_parser() -> argparse.ArgumentParser: ) parser.add_argument( "--mtp_mode", - choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att", None], + choices=[ + "vanilla_with_att", + "eagle_with_att", + "vanilla_no_att", + "eagle_no_att", + None, + ], default=None, help="""Supported MTP modes. None: Disables MTP. @@ -638,6 +655,33 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) + parser.add_argument( + "--mamba_cache_size", + type=int, + default=None, + help="""The size of linear attn cache. If not specified, will be calculated + automatically based on mamba_cache_ratio or max_total_token_num.""", + ) + parser.add_argument( + "--mamba_cache_ratio", + type=lambda v: float(v) + if 0.0 <= (_ := float(v)) <= 1.0 + else (_ for _ in ()).throw( + argparse.ArgumentTypeError(f"--mamba_cache_ratio must be between 0.0 and 1.0, got {v}") + ), + default=0.5, + help="""Ratio of mamba cache to total cache memory (mamba + KV). + Only effective when both mamba_cache_size and max_total_token_num are not set. + Default is 0.5 (50%% mamba cache, 50%% KV cache). + Example: 0.3 -> 30%% mamba, 70%% KV; 0.7 -> 70%% mamba, 30%% KV.""", + ) + parser.add_argument( + "--mamba_ssm_data_type", + type=str, + choices=["bfloat16", "float32"], + default="float32", + help="the data type of the model weight", + ) parser.add_argument( "--hardware_platform", type=str, diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index 598bd8f1f2..c4cdfa4ba3 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -193,6 +193,13 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req multimodal_params_dict["images"].append({"type": "base64", "data": data}) else: raise ValueError("Unrecognized image input.") + elif img.startswith("file://"): + # Local file path with file:// prefix + file_path = img[7:] # Remove "file://" prefix + with open(file_path, "rb") as f: + multimodal_params_dict["images"].append( + {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")} + ) else: raise ValueError( "Unrecognized image input. Supports local path, http url, base64, and PIL.Image." diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 7f16d519a2..a38008af6f 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -46,6 +46,7 @@ async def build_prompt(request, tools) -> str: global tokenizer # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] + kwargs = {"conversation": messages} if request.character_settings: kwargs["character_settings"] = request.character_settings @@ -57,15 +58,7 @@ async def build_prompt(request, tools) -> str: try: input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools) - except: - # This except branch will be triggered when the chosen model - # has a different tools input format that is not compatiable - # with openAI's apply_chat_template tool_call format, like Mistral. - tools = [t if "function" in t else {"function": t} for t in tools] - input_str = tokenizer.apply_chat_template( - **kwargs, - tokenize=True, - add_generation_prompt=True, - tools=tools, - ) + except BaseException as e: + logger.error(f"Failed to build prompt: {e}") + raise e return input_str diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py index 49b21c38fc..aa71558146 100644 --- a/lightllm/server/core/objs/sampling_params.py +++ b/lightllm/server/core/objs/sampling_params.py @@ -334,15 +334,31 @@ class SamplingParams(ctypes.Structure): def init(self, tokenizer, **kwargs): super().__init__() + # 移除kwargs中为null的参数,避免覆盖默认值 + kwargs = {k: v for k, v in kwargs.items() if v is not None} + self.best_of = kwargs.get("best_of", 1) self.n = kwargs.get("n", self.best_of) - self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample) - self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) - self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) - self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) - self.temperature = kwargs.get("temperature", SamplingParams._temperature) - self.top_p = kwargs.get("top_p", SamplingParams._top_p) - self.top_k = kwargs.get("top_k", SamplingParams._top_k) + do_sample = kwargs.get("do_sample", SamplingParams._do_sample) + self.do_sample = False if do_sample is None else do_sample + + presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty) + self.presence_penalty = 0.0 if presence_penalty is None else presence_penalty + + frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty) + self.frequency_penalty = 0.0 if frequency_penalty is None else frequency_penalty + + repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty) + self.repetition_penalty = 1.0 if repetition_penalty is None else repetition_penalty + + temperature = kwargs.get("temperature", SamplingParams._temperature) + self.temperature = 1.0 if temperature is None else temperature + + top_p = kwargs.get("top_p", SamplingParams._top_p) + self.top_p = 1.0 if top_p is None else top_p + + top_k = kwargs.get("top_k", SamplingParams._top_k) + self.top_k = -1 if top_k is None else top_k self.ignore_eos = kwargs.get("ignore_eos", False) self.image_max_patch_num = kwargs.get("image_max_patch_num", -1) self.max_new_tokens = kwargs.get("max_new_tokens", 16384) @@ -410,13 +426,35 @@ def init(self, tokenizer, **kwargs): def load_generation_cfg(cls, weight_dir): try: generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict() + # Some checkpoints store null sampling fields in generation_config.json. + # Keep robust numeric defaults instead of propagating None into ctypes fields. cls._do_sample = generation_cfg.get("do_sample", False) + if cls._do_sample is None: + cls._do_sample = False + cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0) + if cls._presence_penalty is None: + cls._presence_penalty = 0.0 + cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0) + if cls._frequency_penalty is None: + cls._frequency_penalty = 0.0 + cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0) + if cls._repetition_penalty is None: + cls._repetition_penalty = 1.0 + cls._temperature = generation_cfg.get("temperature", 1.0) + if cls._temperature is None: + cls._temperature = 1.0 + cls._top_p = generation_cfg.get("top_p", 1.0) + if cls._top_p is None: + cls._top_p = 1.0 + cls._top_k = generation_cfg.get("top_k", -1) + if cls._top_k is None: + cls._top_k = -1 except: pass diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d3dc849664..a3f30a5aaa 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -31,7 +31,8 @@ class StartArgs: batch_max_tokens: Optional[int] = field(default=None) eos_id: List[int] = field(default_factory=list) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, + metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]}, ) reasoning_parser: Optional[str] = field( default=None, @@ -54,7 +55,7 @@ class StartArgs: }, ) chat_template: Optional[str] = field(default=None) - running_max_req_size: int = field(default=1000) + running_max_req_size: int = field(default=512) tp: int = field(default=1) dp: int = field(default=1) nnodes: int = field(default=1) @@ -108,7 +109,7 @@ class StartArgs: disable_cudagraph: bool = field(default=False) enable_prefill_cudagraph: bool = field(default=False) prefll_cudagraph_max_handle_token: int = field(default=512) - graph_max_batch_size: int = field(default=256) + graph_max_batch_size: int = field(default=512) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) graph_max_len_in_batch: int = field(default=0) @@ -135,7 +136,18 @@ class StartArgs: ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) mtp_mode: Optional[str] = field( - default=None, metadata={"choices": ["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"]} + default=None, + metadata={ + "choices": [ + "vanilla_with_att", + "eagle_with_att", + "vanilla_no_att", + "eagle_no_att", + "qwen3next_vanilla", + "qwen3next_eagle", + None, + ] + }, ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) @@ -160,3 +172,8 @@ class StartArgs: metric_port: int = field(default=None) multinode_httpmanager_port: int = field(default=12345) multi_level_kv_cache_port: int = field(default=None) + + # hybrid attention model (Qwen3Next) + mamba_cache_size: Optional[int] = field(default=None) + mamba_cache_ratio: Optional[float] = field(default=0.5) + mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py index 3a8fddf744..13aab66179 100644 --- a/lightllm/server/function_call_parser.py +++ b/lightllm/server/function_call_parser.py @@ -11,6 +11,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import ast import json import orjson import logging @@ -1717,6 +1718,228 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami return StreamingParseResult(normal_text="", calls=calls) +class Qwen3CoderDetector(BaseFormatDetector): + """ + Detector for Qwen3-Coder XML-style function call format. + + Format Structure: + ``` + + + + value1 + + + value2 + + + + ``` + + Key differences from Qwen25Detector (JSON-based): + - Parameters are XML key-value pairs, not JSON objects + - Function name is embedded in the tag attribute + - Values need schema-aware type conversion (string by default) + + Reference: https://docs.vllm.ai/projects/recipes/en/latest/Qwen/Qwen3-Coder-480B-A35B.html + """ + + def __init__(self): + super().__init__() + self.bot_token = "" + self.eot_token = "" + self.tool_call_separator = "\n" + + # Regex patterns + self.tool_call_block_regex = re.compile(r"(.*?)", re.DOTALL) + self.function_regex = re.compile(r"||(?=)|$)", re.DOTALL + ) + self._normal_text_buffer = "" + + def has_tool_call(self, text: str) -> bool: + return " Dict: + """Extract parameter type configuration from tool definitions.""" + for tool in tools: + if tool.function.name == func_name and tool.function.parameters: + params = tool.function.parameters + if isinstance(params, dict) and "properties" in params: + return params["properties"] + elif isinstance(params, dict): + return params + return {} + + def _convert_param_value(self, value: str, param_name: str, param_config: Dict, func_name: str) -> Any: + """Convert parameter value based on schema type. Safe alternative to eval().""" + if value.lower() == "null": + return None + + if param_name not in param_config: + return value + + prop = param_config.get(param_name, {}) + param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string" + + if param_type in ("string", "str", "enum"): + return value + elif param_type.startswith("int") or param_type == "integer": + try: + return int(value) + except (ValueError, TypeError): + return value + elif param_type in ("number", "float", "double"): + try: + fv = float(value) + return int(fv) if fv == int(fv) else fv + except (ValueError, TypeError): + return value + elif param_type in ("boolean", "bool"): + return value.lower() == "true" + elif param_type in ("object", "array"): + try: + return json.loads(value) + except (json.JSONDecodeError, TypeError, ValueError): + try: + return ast.literal_eval(value) + except (ValueError, SyntaxError, TypeError): + return value + return value + + def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional[ToolCallItem]: + """Parse a single ... block into a ToolCallItem.""" + try: + end_index = function_str.index(">") + except ValueError: + return None + + func_name = function_str[:end_index].strip() + tool_indices = self._get_tool_indices(tools) + if func_name not in tool_indices: + logger.warning(f"Model attempted to call undefined function: {func_name}") + return None + + parameters_text = function_str[end_index + 1 :] + param_config = self._get_param_config(func_name, tools) + param_dict = {} + + for match in self.parameter_regex.findall(parameters_text): + try: + idx = match.index(">") + except ValueError: + continue + param_name = match[:idx].strip() + param_value = match[idx + 1 :] + # Strip leading/trailing newlines from value + if param_value.startswith("\n"): + param_value = param_value[1:] + if param_value.endswith("\n"): + param_value = param_value[:-1] + + param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name) + + return ToolCallItem( + tool_index=tool_indices[func_name], + name=func_name, + parameters=json.dumps(param_dict, ensure_ascii=False), + ) + + def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult: + idx = text.find(self.bot_token) + normal_text = text[:idx].strip() if idx != -1 else text + + if " StreamingParseResult: + """Streaming incremental parsing for Qwen3-Coder XML tool calls.""" + self._buffer += new_text + current_text = self._buffer + + if not self.has_tool_call(current_text): + partial_len = self._ends_with_partial_token(current_text, self.bot_token) + if partial_len: + return StreamingParseResult() + self._buffer = "" + cleaned = new_text.replace(self.eot_token, "") + return StreamingParseResult(normal_text=cleaned) + + # Check for complete tool call blocks + if self.eot_token in current_text: + result = self.detect_and_parse(current_text, tools) + last_end = current_text.rfind(self.eot_token) + if last_end != -1: + self._buffer = current_text[last_end + len(self.eot_token) :].lstrip() + else: + self._buffer = "" + self.current_tool_id = -1 + self.current_tool_name_sent = False + return result + + # Partial tool call - try to extract function name for early streaming + if not hasattr(self, "_tool_indices"): + self._tool_indices = self._get_tool_indices(tools) + + calls = [] + tool_call_start = current_text.find(self.bot_token) + if tool_call_start == -1: + return StreamingParseResult() + + content_after = current_text[tool_call_start + len(self.bot_token) :] + func_prefix = "") + if gt_pos == -1: + return StreamingParseResult() + + func_name = after_func[:gt_pos].strip() + + if self.current_tool_id == -1: + self.current_tool_id = 0 + self.prev_tool_call_arr = [] + self.streamed_args_for_tool = [""] + + while len(self.prev_tool_call_arr) <= self.current_tool_id: + self.prev_tool_call_arr.append({}) + while len(self.streamed_args_for_tool) <= self.current_tool_id: + self.streamed_args_for_tool.append("") + + if func_name and func_name in self._tool_indices and not self.current_tool_name_sent: + calls.append( + ToolCallItem( + tool_index=self.current_tool_id, + name=func_name, + parameters="", + ) + ) + self.current_tool_name_sent = True + self.prev_tool_call_arr[self.current_tool_id] = {"name": func_name, "arguments": {}} + + return StreamingParseResult(normal_text="", calls=calls) + + class FunctionCallParser: """ Parser for function/tool calls in model outputs. @@ -1736,6 +1959,7 @@ class FunctionCallParser: "mistral": MistralDetector, "qwen": Qwen25Detector, "qwen25": Qwen25Detector, + "qwen3_coder": Qwen3CoderDetector, } def __init__(self, tools: List[Tool], tool_call_parser: str): diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e28e4c93ad..2cbbb7064a 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -296,6 +296,14 @@ async def generate( if self.pd_mode.is_P_or_NORMAL(): await multimodal_params.verify_and_preload(request) + # Debug logging for multimodal requests + if multimodal_params and multimodal_params.images: + logger.debug( + f"[MULTIMODAL_DEBUG] req_id={group_request_id}, " + f"num_images={len(multimodal_params.images)}, " + f"max_new_tokens={sampling_params.max_new_tokens}" + ) + # 记录请求到达的相关信息 await self._log_req_header(request_headers, group_request_id) # encode diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py new file mode 100644 index 0000000000..08f6ba3fff --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -0,0 +1,173 @@ +from typing import Set, Protocol, List, Optional, Tuple + +import torch +from sortedcontainers import SortedSet + +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class HybridRadixCache(RadixCache): + def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager): + super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager) + assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager") + self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager + self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,)) + + def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): + if need_buffer_num > self.buffer_mem_manager.can_use_mem_size: + need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size + + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + release_buffers = [] + + def release_buffer(buffer_idx): + release_buffers.append(buffer_idx) + return + + self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem) + self.buffer_mem_manager.free(release_buffers) + if len(release_mems) > 0: + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return + + def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback): + while need_evict_buffer_num > 0: + node = self.evict_buffer_set.pop(0) + assert node.buffer_idx is not None + evict_buffer_callback(node.buffer_idx) + node.buffer_idx = None + need_evict_buffer_num -= 1 + # 当一个节点的buffer_idx变为None时,事实上无法在后续进行match, + # 但当该节点子节点或者引用数不为0时,仍然需要保留, 否则则应该被删除 + if node.is_leaf() and node.ref_counter == 0: + self.evict_tree_set.discard(node) + evict_token_callback(node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + return + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + miss_prefix_len = 0 + evict_token_list = [] + kv_len = tree_node.node_prefix_total_len + while tree_node != self.root_node and tree_node.buffer_idx is None: + if tree_node.is_leaf(): + self.evict_tree_set.discard(tree_node) + + # Only update ref_counter when update_refs is True to maintain consistency + # with _match_prefix_helper which only increments ref_counter when update_refs=True + if update_refs: + if tree_node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) + tree_node.ref_counter -= 1 # 只减少当前节点,不递归 + + if tree_node.is_leaf() and tree_node.ref_counter == 0: + evict_token_list.append(tree_node.token_mem_index_value) + self.tree_total_tokens_num.arr[0] -= len(tree_node.token_mem_index_value) + parent_node: TreeNode = tree_node.parent + parent_node.remove_child(tree_node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + tree_node = parent_node + else: + if tree_node.is_leaf(): + self.evict_tree_set.add(tree_node) + tree_node = tree_node.parent + miss_prefix_len += len(ans_value_list.pop()) + + if len(evict_token_list) > 0: + evict_token_value = torch.concat(evict_token_list) + self.mem_manager.free(evict_token_value) + + if tree_node == self.root_node: + return None, kv_len - miss_prefix_len, None + + update_node = tree_node + while update_node != self.root_node: + if update_node.buffer_idx is not None: + self.evict_buffer_set.discard(update_node) + update_node.update_buffer_time() + self.evict_buffer_set.add(update_node) + update_node = update_node.parent + + value = torch.concat(ans_value_list) + return tree_node, miss_prefix_len, value + + def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int): + """Set buffer_idx for a node and add it to evict_buffer_set.""" + self.evict_buffer_set.discard(node) + if node.is_leaf(): + self.evict_tree_set.discard(node) + if node.buffer_idx is not None: + self.buffer_mem_manager.free([node.buffer_idx]) + node.buffer_idx = buffer_idx + node.update_buffer_time() + self.evict_buffer_set.add(node) + if node.is_leaf(): + self.evict_tree_set.add(node) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num): + assert self.mem_manager is not None + if need_token_num > self.mem_manager.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + release_buffers = [] + + def release_buffer(buffer_idx): + release_buffers.append(buffer_idx) + return + + self.evict(need_evict_token_num, release_buffer, release_mem) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + if len(release_buffers) > 0: + self.buffer_mem_manager.free(release_buffers) + return + + def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node + ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + if node.buffer_idx is not None: + self.evict_buffer_set.discard(node) + evict_buffer_callback(node.buffer_idx) + node.buffer_idx = None + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 88b099459b..9f8d78c491 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,6 +31,12 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 + # Used by hybrid attention models (e.g., Qwen3Next) to track + # a per-request buffer_idx alongside the token-level KV cache. + # Pure attention models keep buffer_idx as None. + self.buffer_idx = None + self.buffer_time = time_gen.generate_time_id() + def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -78,6 +84,9 @@ def remove_child(self, child_node: "TreeNode"): def update_time(self): self.time_id = time_gen.generate_time_id() + def update_buffer_time(self): + self.buffer_time = time_gen.generate_time_id() + def is_leaf(self): return len(self.children) == 0 @@ -103,10 +112,10 @@ class RadixCache: unique_name 主要用于解决单机,多实列部署时的shm冲突 """ - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager=None): from lightllm.common.kv_cache_mem_manager import MemoryManager - self.mem_manager: MemoryManager = mem_manager + self.mem_manager: MemoryManager = kv_cache_mem_manager self._key_dtype = torch.int64 self._value_dtype = torch.int64 @@ -359,6 +368,7 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: or parent_node.ref_counter != 0 or len(parent_node.children) != 1 or child_node.ref_counter != 0 + or parent_node.buffer_idx is not None ): return None @@ -489,7 +499,7 @@ def _print_helper(self, node: TreeNode, indent): " " * indent, f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ - node_value_len: {node.node_value_len}", + node_value_len: {node.node_value_len} buffer_idx: {node.buffer_idx}", ) for _, child in node.children.items(): self._print_helper(child, indent=indent + 2) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 0a83b101be..731fbea405 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -7,10 +7,11 @@ from dataclasses import dataclass, field from typing import List, Dict, Tuple, Optional, Callable, Any -from lightllm.common.req_manager import ReqManager +from lightllm.common.req_manager import ReqManager, ReqManagerForMamba from lightllm.utils.infer_utils import mark_start, mark_end from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from lightllm.utils.log_utils import init_logger from lightllm.server.req_id_generator import convert_sub_id_to_group_id from lightllm.common.basemodel.infer_lock import g_infer_state_lock @@ -22,6 +23,9 @@ logger = init_logger(__name__) +# Cache for mtp_range tensors to avoid repeated allocation +_mtp_range_cache: Dict[int, torch.Tensor] = {} + @dataclass class InferenceContext: @@ -32,10 +36,15 @@ class InferenceContext: infer_req_ids = None vocab_size = None cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None + mtp_step: int = 0 overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + @property + def has_recurrent_state(self): + return self.req_manager is not None and self.req_manager.has_recurrent_state + def register( self, backend, @@ -57,6 +66,9 @@ def register( self.infer_req_ids = [] self.vocab_size = vocab_size + + self.mtp_step = get_env_start_args().mtp_step + return def init_cpu_embed_cache_client(self): @@ -73,6 +85,31 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream: self.cpu_kv_cache_stream = torch.cuda.Stream() return self.cpu_kv_cache_stream + def _alloc_and_copy_req_buffers( + self, req_manager: ReqManagerForMamba, radix_cache: HybridRadixCache, req_objs: List["InferReq"] + ) -> None: + if not req_objs: + return + + if radix_cache is not None: + radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1)) + + req_idx_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64) + req_manager.alloc_buffer_for_req(req_idx_gpu) + + if radix_cache is not None: + fork_req_ids = [r.req_idx for r in req_objs if r.shared_kv_node is not None] + if fork_req_ids: + src_buf_ids = [r.shared_kv_node.buffer_idx for r in req_objs if r.shared_kv_node is not None] + req_tensor = torch.tensor(fork_req_ids, device="cuda", dtype=torch.int32) + src_tensor = torch.tensor(src_buf_ids, device="cuda", dtype=torch.int32) + + mtp_step = req_manager.mtp_step + if mtp_step not in _mtp_range_cache: + _mtp_range_cache[mtp_step] = torch.arange(0, mtp_step + 1, dtype=torch.int32, device="cuda") + dst_buffers = req_manager.req_to_buffer_index[req_tensor[:, None], _mtp_range_cache[mtp_step][None, :]] + req_manager.buffer_mem_manager.fork_state_buffers(src_tensor, dst_buffers) + def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]: req_objs = [] request_ids = [] @@ -111,6 +148,9 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: slave_req: InferReq = slave_req slave_req.related_master_req = master_req + if isinstance(self.req_manager, ReqManagerForMamba): + self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, req_objs) + return req_objs def free_a_req_mem(self, free_token_index: List, req: "InferReq"): @@ -122,7 +162,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, _ = self.radix_cache.insert(key, value) + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -130,6 +171,42 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): self.radix_cache.dec_node_ref_counter(req.shared_kv_node) req.shared_kv_node = None + def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool: + if self.radix_cache is None: + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len]) + else: + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() + + prefix_len, node = self.radix_cache.insert(key, value) + old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len + free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) + if req.shared_kv_node is not None: + assert req.shared_kv_node.node_prefix_total_len <= prefix_len + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + req.shared_kv_node = None + + if node.buffer_idx is None: + req_to_buffer_index = self.req_manager.req_to_buffer_index + buffer_idx = req_to_buffer_index[req.req_idx, 0].item() + self.radix_cache.add_buffer_idx_to_node(node, buffer_idx) + # 该请求的 buffer 已经被插入到 radix cache 中,不需要手动释放 + return False + return True + + def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"): + """释放请求的 KV cache 和 buffer 内存""" + if self.has_recurrent_state: + need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req) + req_to_buffer_index = self.req_manager.req_to_buffer_index + if need_free_base_buffer: + free_buffer_index.extend(req_to_buffer_index[req.req_idx, :].tolist()) + elif self.mtp_step > 0: + free_buffer_index.extend(req_to_buffer_index[req.req_idx, 1:].tolist()) + else: + self.free_a_req_mem(free_token_index, req) + def _save_promptcache_kvbuffer(self): """ save prompt cache kv buffer @@ -151,19 +228,23 @@ def _filter(self, finished_request_ids: List[int]): free_req_index = [] free_token_index = [] + free_buffer_index = [] for request_id in finished_request_ids: req: InferReq = self.requests_mapping.pop(request_id) if self.args.diverse_mode: req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req) - + self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req) free_req_index.append(req.req_idx) # logger.info(f"infer release req id {req.shm_req.request_id}") req.shm_req.shm_infer_released = True self.shm_req_manager.put_back_req_obj(req.shm_req) - free_token_index = custom_cat(free_token_index) - self.req_manager.free(free_req_index, free_token_index) + if len(free_token_index) != 0: + free_token_index = custom_cat(free_token_index) + self.req_manager.free(free_req_index, free_token_index) + + if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba): + self.req_manager.free_buffer(free_buffer_index) finished_req_ids_set = set(finished_request_ids) self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set] @@ -191,12 +272,15 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if pause_reqs: g_infer_state_lock.acquire() + pause_req_indices = [] free_token_index = [] + free_buffer_index = [] for req in pause_reqs: + pause_req_indices.append(req.req_idx) if self.args.diverse_mode: # 发生暂停的时候,需要清除 diverse 模式下的主从关系 req.clear_master_slave_state() - self.free_a_req_mem(free_token_index, req) + self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req) req.cur_kv_len = 0 req.shm_req.shm_cur_kv_len = req.cur_kv_len assert req.wait_pause is True @@ -210,13 +294,16 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) + if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba): + self.req_manager.free_buffer(free_buffer_index) + g_infer_state_lock.release() return self def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int): if paused_reqs: g_infer_state_lock.acquire() - + revovered_reqs = [] for req in paused_reqs: prefill_need_token_num = req.get_cur_total_len() if prefill_need_token_num > can_alloc_token_num: @@ -228,7 +315,9 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo req.shm_req.is_paused = False logger.debug(f"infer recover paused req id {req.req_id}") can_alloc_token_num -= prefill_need_token_num + revovered_reqs.append(req) + self._alloc_and_copy_req_buffers(revovered_reqs) g_infer_state_lock.release() return @@ -466,8 +555,8 @@ def get_input_token_ids(self): return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()] def get_chuncked_input_token_ids(self): - chunked_start = self.cur_kv_len - chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size) + # 复用 get_chuncked_input_token_len 的逻辑,保持一致性 + chunked_end = self.get_chuncked_input_token_len() return self.shm_req.shm_prompt_ids.arr[0:chunked_end] def get_chuncked_input_token_len(self): diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 8b085c45ed..08932e4e41 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -9,7 +9,6 @@ from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.log_utils import init_logger from lightllm.models import get_model -from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -172,12 +171,14 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + + radix_cache_class = self.model.get_radix_class() self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, - mem_manager=self.model.mem_manager, + kv_cache_mem_manager=self.model.mem_manager, ) if self.use_dynamic_prompt_cache else None @@ -287,9 +288,8 @@ def decode(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]): raise NotImplementedError() def init_mtp_draft_model(self, main_kvargs: dict): - # 当前只支持 deepseekv3 模式的 mtp self.mtp_step = self.args.mtp_step - self.draft_models: List[Deepseek3MTPModel] = [] + self.draft_models = [] os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" @@ -302,6 +302,7 @@ def init_mtp_draft_model(self, main_kvargs: dict): for i in range(num_mtp_modules): mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) + model_type = mtp_model_cfg.get("model_type", "") mtp_model_kvargs = { "weight_dir": self.args.mtp_draft_model_dir[i], "max_total_token_num": self.model.mem_manager.size, @@ -324,21 +325,22 @@ def init_mtp_draft_model(self, main_kvargs: dict): "mtp_previous_draft_models": self.draft_models.copy(), } - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i]) - if mtp_model_cfg["model_type"] == "deepseek_v3": + # Select MTP model class based on model type + model_type = mtp_model_cfg.get("model_type", "") + if model_type == "deepseek_v3": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "qwen3_moe": + elif model_type == "qwen3_moe": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs)) - elif mtp_model_cfg["model_type"] == "mistral": + elif model_type == "mistral": assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"] self.draft_models.append(MistralMTPModel(mtp_model_kvargs)) elif mtp_model_cfg["model_type"] == "glm4_moe_lite": assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"] self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs)) else: - assert False, f"error mtp mode {mtp_model_cfg['model_type']}" + raise ValueError(f"Unsupported MTP model type: {model_type}") self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}") return diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 2800bf0f6b..25726b2578 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -112,6 +112,14 @@ def get_tokenizer( tokenizer = QWen3VLTokenizer( tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg ) + elif model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg: + from transformers import AutoProcessor + from ..models.qwen3_5.model import QWen3_5Tokenizer + + processor = AutoProcessor.from_pretrained(tokenizer_name) + tokenizer = QWen3_5Tokenizer( + tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg + ) elif model_cfg.get("thinker_config") is not None: from transformers import AutoProcessor diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 3e97f4de3e..ed4665e725 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -68,7 +68,7 @@ def exposed_init_model(self, kvargs): self.model = ( Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) - elif self.model_type in ["qwen3_vl", "qwen3_vl_moe"]: + elif self.model_type in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]: self.model = ( Qwen3VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16() ) diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index 54f20384b7..a4fbc594bc 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -87,6 +87,22 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]: except: pass + # Qwen3.5 checkpoints can have an eos_token_id in config that differs from + # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable + # stop id (<|im_end|>) for detokenization/stop behavior. + try: + config_json = get_config_json(model_path) + model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type") + if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}: + from transformers import AutoTokenizer + + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False) + if tokenizer.eos_token_id is not None: + return [int(tokenizer.eos_token_id)] + except Exception: + # Fall back to config-based lookup below. + pass + eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"]) if isinstance(eos_token_id, int): return [eos_token_id] @@ -186,6 +202,8 @@ def has_vision_module(model_path: str) -> bool: ): # Qwen3OmniMoeVisionTransformerPretrainedModel return True + elif model_type in ["qwen3_5", "qwen3_5_moe"]: + return True else: raise Exception("unknown vision model type") except: diff --git a/test_gsmk.py b/test_gsmk.py new file mode 100644 index 0000000000..78a5aa467f --- /dev/null +++ b/test_gsmk.py @@ -0,0 +1,241 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + "repetition_penalty": 1.0, + "top_p": 1.0, + "top_k": 1, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + # First try to find the answer after "####" marker (GSM8K format) + match = re.search(r"####\s*(-?\d+)", answer_str) + if match: + try: + return ast.literal_eval(match.group(1)) + except SyntaxError: + pass + # Fallback: find all numbers and take the last one + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. " + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args)