diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py
index 5c1d2b8712..2463bbb8e8 100755
--- a/lightllm/common/basemodel/basemodel.py
+++ b/lightllm/common/basemodel/basemodel.py
@@ -11,6 +11,7 @@
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.basemodel.infer_struct import InferStateInfo
+from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
from lightllm.common.req_manager import ReqManager
@@ -53,6 +54,9 @@ class TpPartBaseModel:
# infer state class
infer_state_class = InferStateInfo
    def get_radix_class(self):
        """Return the radix-tree prefix-cache class used for this model; a hook
        so model subclasses can substitute a specialized cache implementation."""
        return RadixCache
+
def __init__(self, kvargs):
self.args = get_env_start_args()
self.run_mode = kvargs["run_mode"]
diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
index 9153349c5d..646f998642 100755
--- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
+++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py
@@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
raise Exception("need to impl")
- def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
- input1 = self._att_norm(input_embdings, infer_state, layer_weight)
- q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
- input1 = None
    def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        """Prefill attention sub-block: QKV projection, KV-cache write, context
        attention and output projection.

        Expects the attention-normalized hidden state (the caller applies
        _att_norm first) and returns the attention output, summed across TP
        ranks when tp_world_size_ > 1.
        """
        q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._context_attention_wrapper_run(
            q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
        )
        q = None  # drop reference early so the activation can be reclaimed
        o = self._get_o(o, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        return o
+
+ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ o = self.context_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
@@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings
- def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
- input1 = self._att_norm(input_embdings, infer_state, layer_weight)
- q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
- input1 = None
    def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        """Decode-step attention sub-block: QKV projection, KV-cache write,
        single-token attention kernel and output projection.

        Expects the attention-normalized hidden state and returns the attention
        output, summed across TP ranks when tp_world_size_ > 1.
        """
        q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._token_attention_kernel(q, infer_state, layer_weight)
        q = None  # drop reference early so the activation can be reclaimed
        o = self._get_o(o, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
            all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        return o
+
    def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
        """Full decode-step transformer layer: attention sub-block followed by
        FFN, each added residually into input_embdings (modified in place and
        returned)."""
        input1 = self._att_norm(input_embdings, infer_state, layer_weight)
        o = self.token_attention_forward(input1, infer_state, layer_weight)
        input_embdings.add_(o.view(-1, self.embed_dim_))
        o = None  # release attention output before the FFN allocates
        input1 = self._ffn_norm(input_embdings, infer_state, layer_weight)
        ffn_out = self._ffn(input1, infer_state, layer_weight)
        if self.tp_world_size_ > 1:
            all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
        return input_embdings
- def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
- input1 = self._att_norm(input_embdings, infer_state, layer_weight)
- q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
- input1 = None
    def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
        """TP+SP prefill attention sub-block.

        Same structure as context_attention_forward but routed through the
        _tpsp_* projection helpers; no explicit all-reduce is issued here —
        NOTE(review): presumably the _tpsp_get_o path handles the needed
        communication, confirm against the tpsp helpers.
        """
        q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._context_attention_wrapper_run(
            q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
        )
        q = None  # drop reference early so the activation can be reclaimed
        o = self._tpsp_get_o(o, infer_state, layer_weight)
        return o
+
+ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
@@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings
- def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
- input1 = self._att_norm(input_embdings, infer_state, layer_weight)
- q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
- input1 = None
    def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
        """TP+SP decode-step attention sub-block.

        Same structure as token_attention_forward but routed through the
        _tpsp_* projection helpers; no explicit all-reduce is issued here.
        """
        q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
        self._post_cache_kv(cache_kv, infer_state, layer_weight)
        o = self._token_attention_kernel(q, infer_state, layer_weight)
        q = None  # drop reference early so the activation can be reclaimed
        o = self._tpsp_get_o(o, infer_state, layer_weight)
        return o
+
+ def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
+ input1 = self._att_norm(input_embdings, infer_state, layer_weight)
+ o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
index edf7fe21b9..21b5b7959e 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
@@ -7,7 +7,16 @@
QKVROWNMMWeight,
COLMMWeight,
)
-from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight
+from .norm_weight import (
+ TpRMSNormWeight,
+ RMSNormWeight,
+ GEMMANormWeight,
+ LayerNormWeight,
+ NoTpGEMMANormWeight,
+ QKRMSNORMWeight,
+ QKGEMMANormWeight,
+)
from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight
from .att_sink_weight import TpAttSinkWeight
from .fused_moe.fused_moe_weight import FusedMoeWeight
+from .parameter_weight import ParameterWeight, TpParameterWeight
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
index 20416606f9..89a3d24119 100644
--- a/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py
@@ -71,6 +71,14 @@ def __call__(
return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func)
class GEMMANormWeight(RMSNormWeight):
    """RMSNorm weight for Gemma-style checkpoints.

    Gemma checkpoints store the norm weight in offset form (w - 1), so +1 is
    applied after loading to recover the multiplier the shared kernel expects.
    """

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        if self.weight_name in weights:
            self.weight.copy_(weights[self.weight_name])
            # Shift from the stored (w - 1) form back to w, in place.
            self.weight += 1
            # Flag checked later by weight-load verification.
            self.weight.load_ok = True
+
+
class LayerNormWeight(BaseWeightTpl, PlatformAwareOp):
def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None):
super().__init__(tp_rank=0, tp_world_size=1)
@@ -276,3 +284,23 @@ def __call__(
eps: float,
) -> None:
return self._forward(q=q, k=k, eps=eps)
+
+
class QKGEMMANormWeight(QKRMSNORMWeight):
    """Fused Q/K RMSNorm weights for Gemma-style checkpoints.

    Weights are stored in offset form (w - 1) and shifted back after loading;
    the fused kernel multiplies in fp32 to match HF Gemma numerics.
    """

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        if self.q_weight_name in weights:
            self.q_weight.copy_(weights[self.q_weight_name])
            # Shift from the stored (w - 1) form back to w, in place.
            self.q_weight += 1
            self.q_weight.load_ok = True
        if self.k_weight_name in weights:
            self.k_weight.copy_(weights[self.k_weight_name])
            self.k_weight += 1
            self.k_weight.load_ok = True

    def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        # So we need to set fp32_multiply to True here.
        return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps, fp32_multiply=True)
diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
new file mode 100644
index 0000000000..0afb0ecab2
--- /dev/null
+++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py
@@ -0,0 +1,83 @@
+import torch
+from typing import Dict, Optional, Tuple
+from .base_weight import BaseWeightTpl
+
+
class ParameterWeight(BaseWeightTpl):
    """A plain (unsharded) parameter loaded from a HF checkpoint.

    Storage is pre-allocated from ``weight_shape`` / ``bias_shape``; each
    tensor carries a ``load_ok`` attribute that ``verify_load`` inspects.
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        weight_shape: Optional[Tuple[int, ...]],
        bias_name: Optional[str] = None,
        bias_shape: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__()
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.data_type_ = data_type
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        # Allocated lazily in _create_weight; stay None when no shape is given.
        self.weight: Optional[torch.Tensor] = None
        self.bias: Optional[torch.Tensor] = None
        if weight_shape is not None:
            self._create_weight()

    def _create_weight(self):
        # Allocate uninitialized storage on this rank's device; load_ok tracks
        # whether checkpoint data has actually been copied in.
        if self.weight_shape is not None:
            self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_)
            self.weight.load_ok = False
        if self.bias_name is not None and self.bias_shape is not None:
            self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_)
            self.bias.load_ok = False

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy matching checkpoint tensors into pre-allocated storage.

        NOTE(review): assumes _create_weight has run (shapes were provided);
        a matching key with self.weight still None would crash here — confirm
        all call sites pass shapes up front.
        """
        if self.weight_name in weights:
            t_weight = weights[self.weight_name]
            self.weight.copy_(t_weight.to(self.data_type_))
            self.weight.load_ok = True
        if self.bias_name is not None and self.bias_name in weights:
            t_bias = weights[self.bias_name]
            self.bias.copy_(t_bias.to(self.data_type_))
            self.bias.load_ok = True

    def verify_load(self) -> bool:
        """Return True when every allocated tensor has received data."""
        if self.weight is not None and not getattr(self.weight, "load_ok", False):
            return False
        if self.bias is not None and not getattr(self.bias, "load_ok", False):
            return False
        return True
+
+
class TpParameterWeight(ParameterWeight):
    """Tensor-parallel parameter: each rank holds a contiguous slice of
    ``split_n_embed`` rows along dimension 0."""

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        split_n_embed: int,
        bias_name: Optional[str] = None,
        weight_shape: Optional[Tuple[int, ...]] = None,
        bias_shape: Optional[Tuple[int, ...]] = None,
    ):
        self.split_n_embed = split_n_embed
        # Calculate TP-split shapes if full shapes are provided
        # (dim 0 shrinks to this rank's slice; trailing dims are kept).
        tp_weight_shape = None
        tp_bias_shape = None
        if weight_shape is not None:
            tp_weight_shape = (split_n_embed,) + weight_shape[1:]
        if bias_shape is not None:
            tp_bias_shape = (split_n_embed,) + bias_shape[1:]
        super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape)

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy this rank's row slice [start:end) of each checkpoint tensor."""
        start = self.split_n_embed * self.tp_rank_
        end = self.split_n_embed * (self.tp_rank_ + 1)

        if self.weight_name in weights:
            t_weight = weights[self.weight_name][start:end]
            self.weight.copy_(t_weight.to(self.data_type_))
            self.weight.load_ok = True
        if self.bias_name is not None and self.bias_name in weights:
            t_bias = weights[self.bias_name][start:end]
            self.bias.copy_(t_bias.to(self.data_type_))
            self.bias.load_ok = True
diff --git a/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
new file mode 100644
index 0000000000..bd2aaed530
--- /dev/null
+++ b/lightllm/common/basemodel/triton_kernel/mamba_buffer_copy.py
@@ -0,0 +1,247 @@
+import torch
+import triton
+import triton.language as tl
+from lightllm.common.triton_utils.autotuner import autotune
+
+
# One program instance copies a BLOCK_D-wide chunk of one (index pair, layer)
# slot; grid = (num_pairs, layer_num, cdiv(d_size, BLOCK_D)).
@triton.jit
def _copy_buffer_kernel(
    src_ptr,
    dst_ptr,
    src_idx_ptr,
    dst_idx_ptr,
    stride_layer,
    stride_slot,
    d_size,
    BLOCK_D: tl.constexpr,
):
    pair_idx = tl.program_id(0)
    layer_idx = tl.program_id(1)
    block_d = tl.program_id(2)

    # Promote strides to int64 so address arithmetic cannot overflow on
    # large buffers.
    stride_layer = stride_layer.to(tl.int64)
    stride_slot = stride_slot.to(tl.int64)

    # Gather the physical slot ids for this pair from the index tensors.
    src_slot = tl.load(src_idx_ptr + pair_idx).to(tl.int64)
    dst_slot = tl.load(dst_idx_ptr + pair_idx).to(tl.int64)

    offs = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offs < d_size  # guard the (possibly partial) tail chunk

    base = layer_idx * stride_layer
    tl.store(
        dst_ptr + base + dst_slot * stride_slot + offs,
        tl.load(src_ptr + base + src_slot * stride_slot + offs, mask=mask),
        mask=mask,
    )
+
+
# Like _copy_buffer_kernel, but each source slot fans out to num_dst_per_src
# destination slots; grid axis 0 enumerates the flattened (src, dst) pairs.
@triton.jit
def _fork_buffer_kernel(
    src_ptr,
    dst_ptr,
    src_idx_ptr,
    dst_idx_ptr,
    stride_layer,
    stride_slot,
    d_size,
    num_dst_per_src,
    BLOCK_D: tl.constexpr,
):
    flat_pair = tl.program_id(0)
    layer_idx = tl.program_id(1)
    block_d = tl.program_id(2)

    # Each consecutive run of num_dst_per_src flat pairs shares one source.
    src_chunk = flat_pair // num_dst_per_src

    # Promote strides to int64 so address arithmetic cannot overflow.
    stride_layer = stride_layer.to(tl.int64)
    stride_slot = stride_slot.to(tl.int64)

    src_slot = tl.load(src_idx_ptr + src_chunk).to(tl.int64)
    dst_slot = tl.load(dst_idx_ptr + flat_pair).to(tl.int64)

    offs = block_d * BLOCK_D + tl.arange(0, BLOCK_D)
    mask = offs < d_size  # guard the (possibly partial) tail chunk

    base = layer_idx * stride_layer
    tl.store(
        dst_ptr + base + dst_slot * stride_slot + offs,
        tl.load(src_ptr + base + src_slot * stride_slot + offs, mask=mask),
        mask=mask,
    )
+
+
+def _get_buffer_copy_configs():
+ configs = []
+ for block_d in [128, 256, 512, 1024, 2048, 4096]:
+ for num_warps in [1, 2, 4, 8]:
+ for num_stages in [1, 2]:
+ configs.append({"BLOCK_D": block_d, "num_warps": num_warps, "num_stages": num_stages})
+ return configs
+
+
+def _get_copy_static_key(
+ src_buffer: torch.Tensor,
+):
+ d_size = (
+ src_buffer.shape[2]
+ if src_buffer.ndim == 3
+ else src_buffer.numel() // (src_buffer.shape[0] * src_buffer.shape[1])
+ )
+ return {
+ "dtype": str(src_buffer.dtype),
+ "d_size": d_size,
+ "layer_num": src_buffer.shape[0],
+ "ndim": src_buffer.ndim,
+ }
+
+
def _get_copy_run_key(src_buffer: torch.Tensor):
    """Runtime autotune key; constant, so one tuned config serves every
    launch size for a given static key."""
    return 0
+
+
+def _get_fork_static_key(src_buffer: torch.Tensor):
+ d_size = (
+ src_buffer.shape[2]
+ if src_buffer.ndim == 3
+ else src_buffer.numel() // (src_buffer.shape[0] * src_buffer.shape[1])
+ )
+ return {
+ "dtype": str(src_buffer.dtype),
+ "d_size": d_size,
+ "layer_num": src_buffer.shape[0],
+ "ndim": src_buffer.ndim,
+ }
+
+
def _get_fork_run_key(src_buffer: torch.Tensor):
    """Runtime autotune key; constant, so one tuned config serves every
    launch size for a given static key."""
    return 0
+
+
+def _flatten_trailing_dims(buffer: torch.Tensor) -> torch.Tensor:
+ if buffer.ndim == 3:
+ return buffer
+ L, B = buffer.shape[:2]
+ return buffer.view(L, B, -1)
+
+
@autotune(
    kernel_name="mamba_buffer_copy_1d:v1",
    configs_gen_func=_get_buffer_copy_configs,
    static_key_func=_get_copy_static_key,
    run_key_func=_get_copy_run_key,
)
def _copy_mamba_buffer_autotuned(
    src_buffer: torch.Tensor,
    dst_buffer: torch.Tensor,
    src_indexes: torch.Tensor,
    dst_indexes: torch.Tensor,
    run_config: dict = None,
):
    """Launch the slot-copy kernel on 3D [layer, slot, features] buffers.

    run_config is injected by the autotuner; when absent, a heuristic
    fallback config is derived from the feature size.
    """
    if not run_config:
        d_size = src_buffer.shape[2]
        BLOCK_D = min(4096, triton.next_power_of_2(d_size))
        num_warps = 4 if BLOCK_D >= 1024 else 2
        run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1}

    config = run_config
    BLOCK_D = config["BLOCK_D"]
    num_pairs = src_indexes.shape[0]
    layer_num = src_buffer.shape[0]
    d_size = src_buffer.shape[2]

    num_blocks_d = triton.cdiv(d_size, BLOCK_D)

    # One program per (index pair, layer, feature chunk).
    grid = (num_pairs, layer_num, num_blocks_d)
    _copy_buffer_kernel[grid](
        src_buffer,
        dst_buffer,
        src_indexes,
        dst_indexes,
        src_buffer.stride(0),
        src_buffer.stride(1),
        d_size,
        BLOCK_D=BLOCK_D,
        num_warps=config["num_warps"],
        num_stages=config["num_stages"],
    )
+
+
@autotune(
    kernel_name="mamba_buffer_fork_1d:v1",
    configs_gen_func=_get_buffer_copy_configs,
    static_key_func=_get_fork_static_key,
    run_key_func=_get_fork_run_key,
)
def _fork_mamba_buffer_autotuned(
    src_buffer: torch.Tensor,
    dst_buffer: torch.Tensor,
    src_indexes: torch.Tensor,
    dst_indexes_flat: torch.Tensor,
    num_dst_per_src: int,
    run_config: dict = None,
):
    """Launch the slot-fork kernel: each of the num_src source slots is
    broadcast to num_dst_per_src destination slots.

    dst_indexes_flat is the row-major flattening of the 2D destination index
    matrix. run_config is injected by the autotuner; when absent, a heuristic
    fallback config is derived from the feature size.
    """
    if not run_config:
        d_size = src_buffer.shape[2]
        BLOCK_D = min(4096, triton.next_power_of_2(d_size))
        num_warps = 4 if BLOCK_D >= 1024 else 2
        run_config = {"BLOCK_D": BLOCK_D, "num_warps": num_warps, "num_stages": 1}

    config = run_config
    BLOCK_D = config["BLOCK_D"]
    num_src = src_indexes.shape[0]
    layer_num = src_buffer.shape[0]
    d_size = src_buffer.shape[2]

    num_blocks_d = triton.cdiv(d_size, BLOCK_D)
    total_pairs = num_src * num_dst_per_src

    # One program per (flattened src/dst pair, layer, feature chunk).
    grid = (total_pairs, layer_num, num_blocks_d)
    _fork_buffer_kernel[grid](
        src_buffer,
        dst_buffer,
        src_indexes,
        dst_indexes_flat,
        src_buffer.stride(0),
        src_buffer.stride(1),
        d_size,
        num_dst_per_src,
        BLOCK_D=BLOCK_D,
        num_warps=config["num_warps"],
        num_stages=config["num_stages"],
    )
+
+
def copy_mamba_buffer(
    src_buffer: torch.Tensor,
    dst_buffer: torch.Tensor,
    src_indexes: torch.Tensor,
    dst_indexes: torch.Tensor,
):
    """Pairwise slot copy: src_buffer[:, src_indexes[i]] -> dst_buffer[:, dst_indexes[i]].

    Buffers are [layer_num, slot_num, *state_shape]; trailing state dims are
    flattened (as a view) before launching the Triton kernel.
    """
    assert src_buffer.shape == dst_buffer.shape
    assert src_indexes.shape == dst_indexes.shape and src_indexes.ndim == 1

    src_flat = _flatten_trailing_dims(src_buffer)
    dst_flat = _flatten_trailing_dims(dst_buffer)
    _copy_mamba_buffer_autotuned(src_flat, dst_flat, src_indexes, dst_indexes)
+
+
def fork_mamba_buffer(
    src_buffer: torch.Tensor,
    dst_buffer: torch.Tensor,
    src_indexes: torch.Tensor,
    dst_indexes: torch.Tensor,
):
    """Broadcast each source slot to several destination slots.

    dst_indexes is [num_src, num_dst_per_src]: row i lists the slots that
    receive a copy of slot src_indexes[i]. Buffers are
    [layer_num, slot_num, *state_shape].
    """
    assert src_buffer.shape == dst_buffer.shape
    assert src_indexes.ndim == 1
    assert dst_indexes.ndim == 2, f"dst_indexes must be 2D [num_src, num_dst_per_src], got {dst_indexes.shape}"
    assert (
        dst_indexes.shape[0] == src_indexes.shape[0]
    ), f"Mismatch: src_indexes {src_indexes.shape[0]} vs dst_indexes rows {dst_indexes.shape[0]}"

    num_dst_per_src = dst_indexes.shape[1]
    # Kernel consumes a row-major flat view of the destination matrix.
    dst_indexes_flat = dst_indexes.reshape(-1).contiguous()

    src_flat = _flatten_trailing_dims(src_buffer)
    dst_flat = _flatten_trailing_dims(dst_buffer)
    _fork_mamba_buffer_autotuned(src_flat, dst_flat, src_indexes, dst_indexes_flat, num_dst_per_src)
diff --git a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
index fab0141158..e152a8dd83 100644
--- a/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
+++ b/lightllm/common/basemodel/triton_kernel/norm/qk_norm.py
@@ -78,11 +78,12 @@ def _qk_rms_norm_fused_kernel(
WK_ptr,
stride_k_row,
stride_k_col,
+ eps,
# Dimensions
num_heads_q: tl.constexpr, # Q 的头数 (用于判断边界)
head_dim: tl.constexpr,
- eps,
BLOCK_SIZE: tl.constexpr,
+ FP32_MULTIPLY: tl.constexpr,
):
# PID 0: 处理第几个 Token (Row)
row_idx = tl.program_id(0)
@@ -108,13 +109,15 @@ def _qk_rms_norm_fused_kernel(
# RMSNorm 计算
var = tl.sum(x * x, axis=0) / head_dim
rstd = 1 / tl.sqrt(var + eps)
+ x *= rstd
# 加载 Q 的权重 (假设所有 Head 共享同一组 dim=head_dim 的权重)
w = tl.load(WQ_ptr + offs)
-
- x *= rstd
- y = x.to(w.dtype) * w
-
+ if FP32_MULTIPLY:
+ w = w.to(tl.float32)
+ else:
+ x = x.to(Q_ptr.dtype.element_ty)
+ y = (x * w).to(Q_ptr.dtype.element_ty)
# 写回 Q
tl.store(Q_ptr + q_ptr_offset, y)
@@ -132,18 +135,27 @@ def _qk_rms_norm_fused_kernel(
# RMSNorm 计算
var = tl.sum(x * x, axis=0) / head_dim
rstd = 1 / tl.sqrt(var + eps)
+ x *= rstd
# 加载 K 的权重
w = tl.load(WK_ptr + offs)
- x *= rstd
-
- y = x.to(w.dtype) * w
-
+ if FP32_MULTIPLY:
+ w = w.to(tl.float32)
+ else:
+ x = x.to(K_ptr.dtype.element_ty)
+ y = (x * w).to(K_ptr.dtype.element_ty)
# 写回 K
tl.store(K_ptr + k_ptr_offset, y)
-def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor, w_k: torch.Tensor, eps: float = 1e-6):
+def qk_rmsnorm_fused_forward(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ w_q: torch.Tensor,
+ w_k: torch.Tensor,
+ eps: float = 1e-6,
+ fp32_multiply: bool = False,
+):
"""
In-place RMSNorm for both Q and K in a single kernel launch.
Supports GQA (different number of heads for Q and K).
@@ -197,6 +209,7 @@ def qk_rmsnorm_fused_forward(q: torch.Tensor, k: torch.Tensor, w_q: torch.Tensor
num_heads_q=num_heads_q,
head_dim=head_dim,
eps=eps,
+ FP32_MULTIPLY=fp32_multiply,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=4,
)
diff --git a/lightllm/common/mamba_cache_mem_manager/cache_manager.py b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
new file mode 100644
index 0000000000..9d2d372e17
--- /dev/null
+++ b/lightllm/common/mamba_cache_mem_manager/cache_manager.py
@@ -0,0 +1,207 @@
+from typing import List, Tuple, Union
+
+import torch
+import numpy as np
+
+from lightllm.utils.dist_utils import get_current_rank_in_node
+from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
+from lightllm.common.basemodel.triton_kernel.mamba_buffer_copy import copy_mamba_buffer, fork_mamba_buffer
+from lightllm.utils.log_utils import init_logger
+from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
+
+logger = init_logger(__name__)
+
+MAMBA_CACHE_CAN_USE_NUM_SHM_NAME = f"{get_unique_server_name()}_mamba_cache_can_use_num"
+
+
class LayerCache:
    """Dense per-layer slot storage for one kind of recurrent state.

    Buffer layout is [layer_num, size + 1, *shape]; the extra slot at index
    ``size`` is the reserved hold/placeholder slot (see HOLD_BUFFER_INDEX in
    MambaCacheManager).
    """

    def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int):
        self.size = size
        self.dtype = dtype
        self.shape = shape
        self.layer_num = layer_num

        self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda")

    def get_cell_size(self):
        """Bytes occupied per slot across all layers.

        NOTE(review): torch._utils._element_size is a private helper; the
        public torch.dtype.itemsize could replace it on recent torch versions.
        """
        return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype)
+
+
class MambaCacheManager:
    """Slot allocator plus GPU storage for recurrent (conv + SSM) layer states.

    Slots are handed out stack-style through ``mark_start``; ``free`` expects
    the freed indexes to restore the stack (see the assert there). The free
    slot count is also published to other processes through shared memory.
    """

    def __init__(
        self,
        size: int,
        layer_num: int,
        conv_state_dtype: torch.dtype,
        conv_state_shape: Tuple[int, ...],
        ssm_state_dtype: torch.dtype,
        ssm_state_shape: Tuple[int, ...],
    ):
        # init the mem state
        self.size = size
        # Pool of slot ids; pinned host memory for fast async transfers.
        self.mem_state = torch.arange(
            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
        # Rotating staging buffer used to hand out allocation results (alloc()).
        self._mem_state_return = torch.arange(
            0, self.size * 3, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
        self._return_start = 0
        self.mark_start = 0
        self.mark_end = self.size
        self.can_use_mem_size = self.size
        # Shared-memory counter so other processes can read the free-slot count.
        self.shared_can_use_token_num = SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{get_current_rank_in_node()}")
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
        self.HOLD_TOKEN_MEMINDEX = self.size

        # init the layer cache
        self.conv_state_cache = LayerCache(size, conv_state_dtype, conv_state_shape, layer_num)
        self.ssm_state_cache = LayerCache(size, ssm_state_dtype, ssm_state_shape, layer_num)
        # Reserved placeholder slot (the +1 slot allocated by LayerCache).
        self.HOLD_BUFFER_INDEX = size

        logger.warning(
            f"Linear attention state cache size: {size}\n"
            f"Conv state use : "
            f"{self.conv_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n"
            f"Ssm state use : "
            f"{self.ssm_state_cache.get_cell_size() * size / 1024 ** 3} GB Memory.\n"
        )

    def get_mamba_cache(self, layer_idx: int):
        """Return the (conv_state, ssm_state) slot tables for one layer."""
        conv_state = self.conv_state_cache.buffer[layer_idx]
        ssm_state = self.ssm_state_cache.buffer[layer_idx]
        return conv_state, ssm_state

    def copy_state_buffers(self, src_buffer_indexes: torch.Tensor, dst_buffer_indexes: torch.Tensor):
        """Pairwise copy of both conv and SSM states: src[i] -> dst[i]."""
        copy_mamba_buffer(
            self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes
        )
        copy_mamba_buffer(
            self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_indexes, dst_buffer_indexes
        )

    def fork_state_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor):
        """Broadcast both conv and SSM states from each source slot to its
        row of destination slots (dst_buffer_indexes is 2D)."""
        fork_mamba_buffer(
            self.conv_state_cache.buffer, self.conv_state_cache.buffer, src_buffer_index, dst_buffer_indexes
        )
        fork_mamba_buffer(
            self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes
        )

    def fork_ssm_buffers(self, src_buffer_index: torch.Tensor, dst_buffer_indexes: torch.Tensor):
        """
        Fork ONLY SSM states (not conv states) from source indices to destination indices.

        This is used for MTP mode where each buffer maintains its own independent conv state,
        but SSM states need to be synchronized.
        """
        fork_mamba_buffer(
            self.ssm_state_cache.buffer, self.ssm_state_cache.buffer, src_buffer_index, dst_buffer_indexes
        )

    def alloc(self, need_size) -> torch.Tensor:
        """Pop ``need_size`` slot ids off the free stack and return them as a
        pinned CPU int32 tensor (valid until the staging ring wraps)."""
        if need_size > self.mark_end - self.mark_start:
            logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}")
            assert False, "error alloc state"

        start = self.mark_start
        end = self.mark_start + need_size
        self.mark_start += need_size

        self.can_use_mem_size -= need_size
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)

        # Return through a rotating staging buffer to avoid memory races when
        # the result is consumed asynchronously.
        if self._return_start + need_size > self._mem_state_return.shape[0]:
            self._return_start = 0
        ans = self._mem_state_return[self._return_start : self._return_start + need_size]
        ans.copy_(self.mem_state[start:end])
        self._return_start += need_size
        return ans

    def free(self, free_index: Union[torch.Tensor, List[int]]):
        """
        Free the allocated cache buffers and clear them.

        Args:
            free_index: Buffer indices to free (tensor or list of ints)
        """
        # Convert to tensor if needed for indexing
        if isinstance(free_index, list):
            free_index_tensor = torch.tensor(free_index, dtype=torch.long, device="cuda")
        else:
            free_index_tensor = free_index.to(device="cuda", dtype=torch.long)

        # Clear the buffers for the freed indices
        # Shape: [layer_num, buffer_index, *shape]
        self.conv_state_cache.buffer[:, free_index_tensor, ...] = 0
        self.ssm_state_cache.buffer[:, free_index_tensor, ...] = 0

        # update the mem state
        end = self.mark_start
        start = self.mark_start - len(free_index)
        assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}"

        if isinstance(free_index, list):
            free_index_tensor = torch.tensor(free_index, dtype=self.mem_state.dtype, device=self.mem_state.device)
            self.mem_state[start:end] = free_index_tensor
        else:
            # A GPU-to-CPU copy is an in-stream blocking operation.
            self.mem_state[start:end] = free_index

        self.mark_start -= len(free_index)

        self.can_use_mem_size += len(free_index)
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)

        if self.can_use_mem_size == len(self.mem_state):
            logger.debug(f"freed all gpu mem size {self.can_use_mem_size}")

        return

    def free_all(self):
        """Reset the allocator and zero all state storage."""
        self.conv_state_cache.buffer.fill_(0)
        self.ssm_state_cache.buffer.fill_(0)
        self.can_use_mem_size = len(self.mem_state)
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
        self.mem_state.numpy()[:] = list(range(0, len(self.mem_state)))
        self.mark_start = 0
        self.mark_end = len(self.mem_state)

        return

    def resize_mem(self, new_size):
        """
        just for test code
        """
        self.size = new_size
        self.mem_state = torch.arange(
            0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True
        )
        self.mark_start = 0
        self.mark_end = self.size
        self.can_use_mem_size = self.size
        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
        return
+
+
class ReadOnlyStaticsMambaCacheManager:
    """
    Read-only accessor for the cache usage statistics published by
    MambaCacheManager instances through shared memory.
    """

    def __init__(self) -> None:
        args = get_env_start_args()
        self.global_world_size = args.tp
        self.node_world_size = args.tp // args.nnodes
        self.dp_world_size = self.global_world_size // args.dp
        # Compatibility for multi-node pure-TP mode (dp size == 1).
        self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
        # One shared counter per DP-group leader rank within this node.
        self.shared_tp_can_use_token_nums = [
            SharedInt(f"{MAMBA_CACHE_CAN_USE_NUM_SHM_NAME}_{rank_in_node}")
            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
        ]

    def get_unrefed_token_num(self, dp_rank_in_node: int):
        """Return the free-slot count for the given DP rank on this node."""
        if self.is_multinode_tp:
            return self.shared_tp_can_use_token_nums[0].get_value()
        return self.shared_tp_can_use_token_nums[dp_rank_in_node].get_value()
diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py
index 33bdca4475..bbe2bb4a3b 100644
--- a/lightllm/common/req_manager.py
+++ b/lightllm/common/req_manager.py
@@ -3,6 +3,7 @@
from lightllm.utils.log_utils import init_logger
from .kv_cache_mem_manager import MemoryManager
from typing import List, Optional
+
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter
from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter
from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args
@@ -67,6 +68,7 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana
self.req_sampling_params_manager = ReqSamplingParamsManager(max_request_num)
self.max_request_num = max_request_num
self.HOLD_REQUEST_ID = max_request_num
+ self.req_to_buffer_index = None
def alloc(self):
return self.req_list.alloc()
@@ -93,6 +95,11 @@ def free_all(self):
self.req_list = _ReqLinkedList(self.max_request_num)
return
    @property
    def has_recurrent_state(self) -> bool:
        """Whether this model uses per-request recurrent state buffers (e.g. Mamba/linear attention)."""
        return self.req_to_buffer_index is not None
+
class ReqSamplingParamsManager:
"""
@@ -232,3 +239,28 @@ def gen_cpu_out_token_counter_sampling_params(self, req_objs: List):
p_token_counts_tensor.cuda(non_blocking=True),
p_cumsum_seq_len_tensor.cuda(non_blocking=True),
)
+
+
class ReqManagerForMamba(ReqManager):
    """ReqManager variant that additionally tracks per-request Mamba state
    buffer slots: each request owns ``mtp_step + 1`` slots (the base position
    plus one per speculative MTP step)."""

    def __init__(self, max_request_num, max_sequence_length, mem_manager):
        # Local import avoids a circular dependency at module load time.
        from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager

        super().__init__(max_request_num, max_sequence_length, mem_manager)
        self.mtp_step = get_env_start_args().mtp_step
        self.buffer_mem_manager: MambaCacheManager = self.mem_manager.mamba_cache_mem_manager
        # One row per request id (+1 hold row); one column per MTP position.
        self.req_to_buffer_index = torch.zeros(
            (self.max_request_num + 1, self.mtp_step + 1), dtype=torch.int32, device="cuda"
        )
        self.req_to_buffer_index[self.HOLD_REQUEST_ID, :] = self.buffer_mem_manager.HOLD_BUFFER_INDEX

    def free_buffer(self, free_buffer_indexes: List[int]):
        """Release state slots back to the Mamba cache manager."""
        self.buffer_mem_manager.free(free_buffer_indexes)
        return

    def alloc_buffer_for_req(self, req_index: torch.Tensor):
        """Allocate ``mtp_step + 1`` state slots for every request in req_index
        and record them in the request-to-buffer index table."""
        num_reqs = req_index.shape[0]
        num_buffers_per_req = self.mtp_step + 1
        buffer_indexes = self.buffer_mem_manager.alloc(num_reqs * num_buffers_per_req)
        if not buffer_indexes.is_cuda:
            buffer_indexes = buffer_indexes.cuda()
        self.req_to_buffer_index[req_index] = buffer_indexes.view(num_reqs, num_buffers_per_req)
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..cc5c68eb79
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,8 @@
+{
+ "4": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..b6e5109b62
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 4,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..511935b4cf
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,14 @@
+{
+ "2": {
+ "BK": 128,
+ "BV": 128,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 128,
+ "BV": 128,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..cc5c68eb79
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,8 @@
+{
+ "4": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..1038611f6a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..4bc06d07d9
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,14 @@
+{
+ "2": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..7421097fa4
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,8 @@
+{
+ "4": {
+ "BK": 64,
+ "BV": 128,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..f1159e4357
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 64,
+ "BV": 128,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..892c20e78d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,14 @@
+{
+ "2": {
+ "BK": 64,
+ "BV": 128,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 64,
+ "BV": 128,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..d831f32c4a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=12,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "4": {
+ "BV": 32,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..2af1b86e90
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "8": {
+ "BV": 32,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..40cdc996b9
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,12 @@
+{
+ "2": {
+ "BV": 32,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BV": 32,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..833062ec2f
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=12,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,38 @@
+{
+ "1": {
+ "num_warps": 2
+ },
+ "100": {
+ "num_warps": 1
+ },
+ "1024": {
+ "num_warps": 1
+ },
+ "128": {
+ "num_warps": 1
+ },
+ "16": {
+ "num_warps": 2
+ },
+ "16384": {
+ "num_warps": 1
+ },
+ "2048": {
+ "num_warps": 1
+ },
+ "256": {
+ "num_warps": 1
+ },
+ "32": {
+ "num_warps": 1
+ },
+ "4096": {
+ "num_warps": 2
+ },
+ "64": {
+ "num_warps": 8
+ },
+ "8": {
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..5f2cf9465b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,38 @@
+{
+ "1": {
+ "num_warps": 8
+ },
+ "100": {
+ "num_warps": 8
+ },
+ "1024": {
+ "num_warps": 2
+ },
+ "128": {
+ "num_warps": 8
+ },
+ "16": {
+ "num_warps": 8
+ },
+ "16384": {
+ "num_warps": 2
+ },
+ "2048": {
+ "num_warps": 2
+ },
+ "256": {
+ "num_warps": 2
+ },
+ "32": {
+ "num_warps": 8
+ },
+ "4096": {
+ "num_warps": 2
+ },
+ "64": {
+ "num_warps": 2
+ },
+ "8": {
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..c8a1841674
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,38 @@
+{
+ "1": {
+ "num_warps": 8
+ },
+ "100": {
+ "num_warps": 2
+ },
+ "1024": {
+ "num_warps": 8
+ },
+ "128": {
+ "num_warps": 8
+ },
+ "16": {
+ "num_warps": 8
+ },
+ "16384": {
+ "num_warps": 1
+ },
+ "2048": {
+ "num_warps": 8
+ },
+ "256": {
+ "num_warps": 4
+ },
+ "32": {
+ "num_warps": 2
+ },
+ "4096": {
+ "num_warps": 8
+ },
+ "64": {
+ "num_warps": 1
+ },
+ "8": {
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..a97cabf8b2
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=12,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "4": {
+ "BK": 64,
+ "num_stages": 3,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..786624883f
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,7 @@
+{
+ "8": {
+ "BK": 64,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..eaca03cf75
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,12 @@
+{
+ "2": {
+ "BK": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 64,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..d9064e5d6a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=12,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLK_HEADS": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "1024": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "128": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "16": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "256": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "32": {
+ "BLK_HEADS": 8,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "64": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "8": {
+ "BLK_HEADS": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..baef19d90c
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLK_HEADS": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "128": {
+ "BLK_HEADS": 64,
+ "num_warps": 1
+ },
+ "16": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLK_HEADS": 32,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLK_HEADS": 16,
+ "num_warps": 2
+ },
+ "256": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "32": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLK_HEADS": 16,
+ "num_warps": 2
+ },
+ "64": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "8": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..90ac24c408
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "100": {
+ "BLK_HEADS": 4,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLK_HEADS": 8,
+ "num_warps": 2
+ },
+ "128": {
+ "BLK_HEADS": 64,
+ "num_warps": 1
+ },
+ "16": {
+ "BLK_HEADS": 8,
+ "num_warps": 2
+ },
+ "16384": {
+ "BLK_HEADS": 16,
+ "num_warps": 2
+ },
+ "2048": {
+ "BLK_HEADS": 64,
+ "num_warps": 1
+ },
+ "256": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "32": {
+ "BLK_HEADS": 4,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "64": {
+ "BLK_HEADS": 8,
+ "num_warps": 1
+ },
+ "8": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..31d7a6e203
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,118 @@
+{
+ "1024": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "12": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "1200": {
+ "BLOCK_N": 256,
+ "num_warps": 2
+ },
+ "12288": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_N": 256,
+ "num_warps": 8
+ },
+ "131072": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "1536": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "1600": {
+ "BLOCK_N": 64,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "192": {
+ "BLOCK_N": 512,
+ "num_warps": 1
+ },
+ "196608": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_N": 64,
+ "num_warps": 1
+ },
+ "24576": {
+ "BLOCK_N": 64,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "262144": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "3072": {
+ "BLOCK_N": 256,
+ "num_warps": 2
+ },
+ "32768": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "384": {
+ "BLOCK_N": 512,
+ "num_warps": 2
+ },
+ "4096": {
+ "BLOCK_N": 256,
+ "num_warps": 1
+ },
+ "49152": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "512": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_N": 256,
+ "num_warps": 8
+ },
+ "65536": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "768": {
+ "BLOCK_N": 256,
+ "num_warps": 2
+ },
+ "8": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "800": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "8192": {
+ "BLOCK_N": 256,
+ "num_warps": 1
+ },
+ "96": {
+ "BLOCK_N": 512,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..0042ef8a2a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "131072": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..54a5967071
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=128,N=3072,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "131072": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..bb78d1dd84
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..1552d8bf1a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..08cbfd85c3
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "131072": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "512": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..13a070b8f0
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/grouped_matmul:v1/{K=3072,N=256,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..169a148799
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_align_fused:v1/{topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE": 512,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..5022588ef5
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 2,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..4ae96d02d1
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/moe_sum_reduce:v1/{hidden_dim=3072,out_dtype=torch.bfloat16,topk_num=8}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 32,
+ "NUM_STAGE": 2,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 16
+ },
+ "16384": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 16
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 2,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 16
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..28c654f3d2
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "100": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "128": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "16": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "16384": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "2048": {
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "32": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "4096": {
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "num_stages": 2,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..08b0d5e5bc
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=6,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "1024": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "128": {
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16": {
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "16384": {
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2048": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "256": {
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "32": {
+ "num_stages": 5,
+ "num_warps": 8
+ },
+ "4096": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "64": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "8": {
+ "num_stages": 5,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..0d871841ed
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "num_stages": 1,
+ "num_warps": 1
+ },
+ "1024": {
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "128": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "16": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "16384": {
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "32": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "4096": {
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "64": {
+ "num_stages": 4,
+ "num_warps": 1
+ },
+ "8": {
+ "num_stages": 3,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..9f3a8dcb25
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "131072": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "32768": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
new file mode 100644
index 0000000000..72026f01c4
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H100_80GB_HBM3/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H100_80GB_HBM3.json
@@ -0,0 +1,74 @@
+{
+ "1024": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "131072": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "8": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "800": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..5d9216c2ea
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=16,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 128,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..131da59770
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=24,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..338af08a1d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,14 @@
+{
+ "2": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "4": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..131da59770
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=16,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..131da59770
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=24,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..4bc06d07d9
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,14 @@
+{
+ "2": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..f1159e4357
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 64,
+ "BV": 128,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..131da59770
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,8 @@
+{
+ "8": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..c8fa422e0c
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,14 @@
+{
+ "2": {
+ "BK": 128,
+ "BV": 64,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 64,
+ "BV": 128,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..2af1b86e90
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=16,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "8": {
+ "BV": 32,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..2af1b86e90
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=24,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "8": {
+ "BV": 32,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..40cdc996b9
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json
@@ -0,0 +1,12 @@
+{
+ "2": {
+ "BV": 32,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4": {
+ "BV": 32,
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..a40eda35d4
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=16,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json
@@ -0,0 +1,38 @@
+{
+ "1": {
+ "num_warps": 8
+ },
+ "100": {
+ "num_warps": 1
+ },
+ "1024": {
+ "num_warps": 8
+ },
+ "128": {
+ "num_warps": 8
+ },
+ "16": {
+ "num_warps": 8
+ },
+ "16384": {
+ "num_warps": 1
+ },
+ "2048": {
+ "num_warps": 1
+ },
+ "256": {
+ "num_warps": 1
+ },
+ "32": {
+ "num_warps": 8
+ },
+ "4096": {
+ "num_warps": 1
+ },
+ "64": {
+ "num_warps": 1
+ },
+ "8": {
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..808ed9a7fc
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=24,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json
@@ -0,0 +1,38 @@
+{
+ "1": {
+ "num_warps": 1
+ },
+ "100": {
+ "num_warps": 8
+ },
+ "1024": {
+ "num_warps": 2
+ },
+ "128": {
+ "num_warps": 8
+ },
+ "16": {
+ "num_warps": 8
+ },
+ "16384": {
+ "num_warps": 1
+ },
+ "2048": {
+ "num_warps": 8
+ },
+ "256": {
+ "num_warps": 8
+ },
+ "32": {
+ "num_warps": 8
+ },
+ "4096": {
+ "num_warps": 2
+ },
+ "64": {
+ "num_warps": 1
+ },
+ "8": {
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..5b08208be2
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json
@@ -0,0 +1,38 @@
+{
+ "1": {
+ "num_warps": 1
+ },
+ "100": {
+ "num_warps": 8
+ },
+ "1024": {
+ "num_warps": 1
+ },
+ "128": {
+ "num_warps": 1
+ },
+ "16": {
+ "num_warps": 1
+ },
+ "16384": {
+ "num_warps": 2
+ },
+ "2048": {
+ "num_warps": 8
+ },
+ "256": {
+ "num_warps": 8
+ },
+ "32": {
+ "num_warps": 8
+ },
+ "4096": {
+ "num_warps": 8
+ },
+ "64": {
+ "num_warps": 1
+ },
+ "8": {
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..27e4804a61
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=16,IS_VARLEN=true,K=128}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "8": {
+ "BK": 64,
+ "num_stages": 2,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..fb62cf8259
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=24,IS_VARLEN=true,K=128}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "8": {
+ "BK": 32,
+ "num_stages": 3,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json
new file mode 100644
index 0000000000..7749b3601f
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json
@@ -0,0 +1,12 @@
+{
+ "2": {
+ "BK": 64,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "4": {
+ "BK": 64,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..49c4dc63d1
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=16,a_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "100": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "128": {
+ "BLK_HEADS": 8,
+ "num_warps": 2
+ },
+ "16": {
+ "BLK_HEADS": 8,
+ "num_warps": 2
+ },
+ "16384": {
+ "BLK_HEADS": 16,
+ "num_warps": 2
+ },
+ "2048": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "256": {
+ "BLK_HEADS": 8,
+ "num_warps": 2
+ },
+ "32": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "64": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "8": {
+ "BLK_HEADS": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..ad8d397d3b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=24,a_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLK_HEADS": 8,
+ "num_warps": 4
+ },
+ "100": {
+ "BLK_HEADS": 16,
+ "num_warps": 2
+ },
+ "1024": {
+ "BLK_HEADS": 32,
+ "num_warps": 1
+ },
+ "128": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "16": {
+ "BLK_HEADS": 64,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "256": {
+ "BLK_HEADS": 64,
+ "num_warps": 2
+ },
+ "32": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "64": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "8": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..907575d960
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "100": {
+ "BLK_HEADS": 4,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "128": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "16": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "256": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "32": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLK_HEADS": 16,
+ "num_warps": 2
+ },
+ "64": {
+ "BLK_HEADS": 16,
+ "num_warps": 1
+ },
+ "8": {
+ "BLK_HEADS": 64,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..55ccb24a65
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,118 @@
+{
+ "1024": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "128": {
+ "BLOCK_N": 512,
+ "num_warps": 2
+ },
+ "131072": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "1536": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "16": {
+ "BLOCK_N": 256,
+ "num_warps": 4
+ },
+ "1600": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "192": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "2048": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "24": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "2400": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "24576": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "256": {
+ "BLOCK_N": 512,
+ "num_warps": 2
+ },
+ "262144": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "3072": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "32768": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "384": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "393216": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "4096": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "49152": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "512": {
+ "BLOCK_N": 256,
+ "num_warps": 4
+ },
+ "6144": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "65536": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ },
+ "768": {
+ "BLOCK_N": 256,
+ "num_warps": 2
+ },
+ "8": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "800": {
+ "BLOCK_N": 64,
+ "num_warps": 2
+ },
+ "8192": {
+ "BLOCK_N": 128,
+ "num_warps": 2
+ },
+ "98304": {
+ "BLOCK_N": 128,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..e5a383f23f
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "10": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1000": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "10240": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1280": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "160": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "163840": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "20480": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2560": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "320": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "40960": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "640": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "80": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..56c79e3a43
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=128,N=4096,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "10": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1000": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "10240": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1280": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "160": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "163840": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "20480": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "2560": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "320": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "40960": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "640": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "80": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..4843ed8ccf
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..662875ecdb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=256,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=8,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,38 @@
+{
+ "1024": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..3c0e605b00
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=512,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..1f8134fa64
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=256,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,38 @@
+{
+ "131072": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32768": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8192": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..d82ca44a21
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=256,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "10": {
+ "BLOCK_SIZE_K": 32,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1000": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "10240": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1280": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "160": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "163840": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "20480": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2560": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 8
+ },
+ "320": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "40960": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "640": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "80": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 64,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json
new file mode 100644
index 0000000000..96eabffc42
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/grouped_matmul:v1/{K=4096,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json
@@ -0,0 +1,110 @@
+{
+ "1": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 32,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 32,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 32,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "2048": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 64,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 2,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 64,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 5,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_SIZE_K": 64,
+ "BLOCK_SIZE_M": 128,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ },
+ "64": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 1,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "BLOCK_SIZE_K": 128,
+ "BLOCK_SIZE_M": 16,
+ "BLOCK_SIZE_N": 128,
+ "GROUP_SIZE_M": 16,
+ "NEED_TRANS": false,
+ "num_stages": 3,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..69a0e9ca42
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 2048,
+ "num_stages": 2,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..9de6716c3c
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 4096,
+ "num_stages": 1,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..2e3a3febbb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 512,
+ "num_stages": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..5c9f40590b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_copy_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 1024,
+ "num_stages": 1,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..0a3facae38
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=12288,dtype=torch.bfloat16,layer_num=30,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 2048,
+ "num_stages": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..9cdaab5ace
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=15360,dtype=torch.bfloat16,layer_num=48,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 4096,
+ "num_stages": 2,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..2e3a3febbb
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=262144,dtype=torch.float32,layer_num=30,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 512,
+ "num_stages": 1,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json
new file mode 100644
index 0000000000..889f6ab71b
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mamba_buffer_fork_1d:v1/{d_size=393216,dtype=torch.float32,layer_num=48,ndim=3}_NVIDIA_H200.json
@@ -0,0 +1,7 @@
+{
+ "0": {
+ "BLOCK_D": 1024,
+ "num_stages": 2,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json
new file mode 100644
index 0000000000..07e5e6875f
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 2
+ },
+ "128": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 2
+ },
+ "256": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "4096": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 4
+ },
+ "64": {
+ "BLOCK_SIZE": 128,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_SIZE": 512,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json
new file mode 100644
index 0000000000..ff4632955f
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 4,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "100": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "1024": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 16
+ },
+ "16384": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "2048": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "256": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "32": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 2
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 64,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json
new file mode 100644
index 0000000000..89ab51ff8c
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=4096,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "1": {
+ "BLOCK_DIM": 128,
+ "BLOCK_M": 4,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "100": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "1024": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "128": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "16": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "16384": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 2
+ },
+ "2048": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "256": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "32": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ },
+ "4096": {
+ "BLOCK_DIM": 512,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 4,
+ "num_warps": 1
+ },
+ "64": {
+ "BLOCK_DIM": 1024,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 8
+ },
+ "8": {
+ "BLOCK_DIM": 256,
+ "BLOCK_M": 1,
+ "NUM_STAGE": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..f4d29554da
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=256,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "num_stages": 1,
+ "num_warps": 8
+ },
+ "100": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "1024": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "128": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16": {
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "16384": {
+ "num_stages": 1,
+ "num_warps": 2
+ },
+ "2048": {
+ "num_stages": 2,
+ "num_warps": 2
+ },
+ "256": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "32": {
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "4096": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "8": {
+ "num_stages": 5,
+ "num_warps": 8
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..8605a91680
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=4,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "100": {
+ "num_stages": 4,
+ "num_warps": 4
+ },
+ "1024": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "128": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "16": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "16384": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "2048": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "256": {
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "32": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "4096": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "64": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "8": {
+ "num_stages": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..5b3e656b6d
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=1,Q_HEAD_NUM=8,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,18 @@
+{
+ "1024": {
+ "num_stages": 5,
+ "num_warps": 1
+ },
+ "16384": {
+ "num_stages": 4,
+ "num_warps": 2
+ },
+ "2048": {
+ "num_stages": 2,
+ "num_warps": 2
+ },
+ "4096": {
+ "num_stages": 3,
+ "num_warps": 2
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..ada783ef92
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/mrope_triton_fused:v1/{HEAD_DIM=64,K_HEAD_NUM=2,Q_HEAD_NUM=12,dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,50 @@
+{
+ "1": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "100": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "1024": {
+ "num_stages": 2,
+ "num_warps": 1
+ },
+ "128": {
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "16": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "16384": {
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "2048": {
+ "num_stages": 5,
+ "num_warps": 2
+ },
+ "256": {
+ "num_stages": 3,
+ "num_warps": 2
+ },
+ "32": {
+ "num_stages": 3,
+ "num_warps": 1
+ },
+ "4096": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "64": {
+ "num_stages": 3,
+ "num_warps": 4
+ },
+ "8": {
+ "num_stages": 3,
+ "num_warps": 1
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..12993b0231
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,74 @@
+{
+ "10": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 4,
+ "num_warps": 8
+ },
+ "1000": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "10240": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "1280": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 4
+ },
+ "160": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "163840": {
+ "BLOCK_M": 64,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "20480": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2560": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 64,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "320": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "40960": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "80": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json
new file mode 100644
index 0000000000..0a0f01fe7a
--- /dev/null
+++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.5.1/NVIDIA_H200/silu_and_mul_fwd:v1/{N=256,out_dtype=torch.bfloat16}_NVIDIA_H200.json
@@ -0,0 +1,98 @@
+{
+ "10": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "1000": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "10240": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "1280": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "131072": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "160": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "16384": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "163840": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "20480": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "2560": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 4
+ },
+ "320": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "32768": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "40960": {
+ "BLOCK_M": 32,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 4,
+ "num_warps": 1
+ },
+ "640": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 1
+ },
+ "80": {
+ "BLOCK_M": 1,
+ "BLOCK_N": 128,
+ "NUM_STAGES": 2,
+ "num_warps": 1
+ },
+ "8192": {
+ "BLOCK_M": 8,
+ "BLOCK_N": 256,
+ "NUM_STAGES": 1,
+ "num_warps": 4
+ }
+}
\ No newline at end of file
diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py
index f2e29d4a88..2caee91709 100644
--- a/lightllm/models/__init__.py
+++ b/lightllm/models/__init__.py
@@ -7,6 +7,7 @@
from lightllm.models.qwen2.model import Qwen2TpPartModel
from lightllm.models.qwen3.model import Qwen3TpPartModel
from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.qwen3next.model import Qwen3NextTpPartModel
from lightllm.models.internlm.model import InternlmTpPartModel
from lightllm.models.stablelm.model import StablelmTpPartModel
from lightllm.models.internlm2.model import Internlm2TpPartModel
@@ -39,4 +40,6 @@
)
from lightllm.models.gpt_oss.model import GptOssTpPartModel
from lightllm.models.qwen3_omni_moe_thinker.model import Qwen3OmniMOETpPartModel
+from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel
+from lightllm.models.qwen3_5_moe.model import Qwen3_5MOETpPartModel
from .registry import get_model, get_model_class
diff --git a/lightllm/models/llama/layer_weights/transformer_layer_weight.py b/lightllm/models/llama/layer_weights/transformer_layer_weight.py
index 0566c9f1c6..e42a6191e6 100644
--- a/lightllm/models/llama/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/llama/layer_weights/transformer_layer_weight.py
@@ -31,7 +31,7 @@ def _parse_config(self):
head_dim = self.network_config_["hidden_size"] // self.network_config_["num_attention_heads"]
self.head_dim = self.network_config_.get("head_dim", head_dim)
self.n_embed = self.network_config_["hidden_size"]
- self.n_inter = self.network_config_["intermediate_size"]
+ self.n_inter = self.network_config_.get("intermediate_size", -1)
def _init_weight_names(self):
self._q_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"
diff --git a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
index 7156a5ce23..825a985b46 100644
--- a/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
+++ b/lightllm/models/qwen2_5_vl/qwen2_5_visual.py
@@ -227,14 +227,14 @@ def _init_datatype(self):
def rot_pos_emb(self, grid_thw):
pos_ids = []
s = self.spatial_merge_size
- for _, h, w in grid_thw:
+ for t, h, w in grid_thw:
pos_shape = (h // s, s, w // s, s)
hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
hpos_ids = hpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
wpos_ids = wpos_ids.reshape(pos_shape).permute(0, 2, 1, 3).flatten()
- pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1))
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
pos_ids = torch.cat(pos_ids, dim=0)
max_grid_size = grid_thw[:, 1:].max()
cos_full, sin_full = self.rotary_pos_emb(max_grid_size)
diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py
index a29cb8758b..6076756043 100644
--- a/lightllm/models/qwen2_vl/qwen2_visual.py
+++ b/lightllm/models/qwen2_vl/qwen2_visual.py
@@ -57,6 +57,8 @@ def __init__(
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=False)
+ # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup)
+ self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = hidden_states.view(
diff --git a/lightllm/models/qwen2_vl/triton_kernel/mrope.py b/lightllm/models/qwen2_vl/triton_kernel/mrope.py
index 5aed658626..e488a0ce30 100644
--- a/lightllm/models/qwen2_vl/triton_kernel/mrope.py
+++ b/lightllm/models/qwen2_vl/triton_kernel/mrope.py
@@ -193,10 +193,11 @@ def mrope_triton_fused(
sin: torch.Tensor,
mrope_section: torch.Tensor,
is_interleaved: bool,
+ partial_rotary_factor: float = 1.0,
run_config: Optional[dict] = None,
):
head_num_q, head_num_k = q.shape[1], k.shape[1]
- head_dim = int(q.shape[2])
+ head_dim = int(q.shape[2] * partial_rotary_factor)
num_tokens = q.shape[0]
if not run_config:
diff --git a/lightllm/models/qwen2_vl/vision_process.py b/lightllm/models/qwen2_vl/vision_process.py
index f2cd38ec8e..bc313fe467 100644
--- a/lightllm/models/qwen2_vl/vision_process.py
+++ b/lightllm/models/qwen2_vl/vision_process.py
@@ -187,7 +187,10 @@ def _preprocess_bydevice(self, image, device="cuda") -> Tuple[torch.Tensor, torc
if image.mode != "RGB":
image = image.convert("RGB")
image_arr = np.asarray(image, dtype=np.uint8)
- image_data = torch.from_numpy(image_arr).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)
+ # Copy to ensure writable array (avoids PyTorch warning for read-only NumPy arrays)
+ image_data = (
+ torch.from_numpy(image_arr.copy()).permute(2, 0, 1).contiguous().to(device=device, non_blocking=True)
+ )
grouped_images, grouped_images_index = group_images_by_shape(
[image_data], disable_grouping=self.disable_grouping
diff --git a/lightllm/models/qwen3_5/__init__.py b/lightllm/models/qwen3_5/__init__.py
new file mode 100644
index 0000000000..56a41a228a
--- /dev/null
+++ b/lightllm/models/qwen3_5/__init__.py
@@ -0,0 +1,16 @@
+"""
+Qwen3.5 Multimodal Model Module (Dense Variant)
+
+Provides Qwen3.5 dense multimodal model with hybrid attention and vision-language support.
+For MoE variant, see qwen3_5_moe module.
+"""
+
+from .model import (
+ Qwen3_5TpPartModel,
+ QWen3_5Tokenizer,
+)
+
+__all__ = [
+ "Qwen3_5TpPartModel",
+ "QWen3_5Tokenizer",
+]
diff --git a/lightllm/models/qwen3_5/infer_struct.py b/lightllm/models/qwen3_5/infer_struct.py
new file mode 100644
index 0000000000..d837c4d291
--- /dev/null
+++ b/lightllm/models/qwen3_5/infer_struct.py
@@ -0,0 +1,17 @@
+import torch
+from typing import List
+
+from lightllm.models.qwen2_vl.infer_struct import Qwen2VLInferStateInfo
+from lightllm.utils.envs_utils import get_env_start_args
+
+
+class Qwen35InferStateInfo(Qwen2VLInferStateInfo):
+ def __init__(self):
+ super().__init__()
+ self.gate_value = None
+
+ def init_some_extra_state(self, model):
+ super().init_some_extra_state(model)
+ self.b_att_seq_len = self.b_seq_len
+ self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous()
+ return
diff --git a/lightllm/models/qwen3_5/layer_infer/__init__.py b/lightllm/models/qwen3_5/layer_infer/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..d0657bcbe8
--- /dev/null
+++ b/lightllm/models/qwen3_5/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,58 @@
+import torch
+import torch.distributed as dist
+from typing import Tuple
+
+from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import (
+ Qwen3NextTransformerLayerInfer,
+)
+from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import (
+ Qwen35TransformerLayerWeight,
+)
+from lightllm.models.qwen2_vl.triton_kernel.mrope import mrope_triton_fused
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class Qwen35TransformerLayerInfer(Qwen3NextTransformerLayerInfer):
+ def __init__(self, layer_num, network_config):
+ super().__init__(layer_num, network_config)
+        # mrope_section from rope_scaling (hard-coded fallback below); NOTE(review): .get("rope_scaling", {}) yields None if the key is explicitly null — confirm upstream config
+ rope_scaling = network_config.get("rope_scaling", {})
+ mrope_section = rope_scaling.get("mrope_section", [11, 11, 10])
+ self.mrope_section = torch.tensor(mrope_section, dtype=torch.int32, device="cuda")
+
+ def _get_qkv(
+ self,
+ input: torch.Tensor,
+ infer_state: LlamaInferStateInfo,
+ layer_weight: Qwen35TransformerLayerWeight,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ input = input.view(-1, self.embed_dim_)
+
+ qkv_out = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv_out.split(
+ [self.tp_q_head_num_ * self.head_dim_, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_], dim=-1
+ )
+ o_gate = layer_weight._o_gate_proj.mm(input)
+
+ # In-place sigmoid for gate
+ infer_state.gate_value = o_gate.sigmoid_()
+ layer_weight.qk_norm_weight_(
+ q,
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ eps=self.eps_,
+ )
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+ mrope_triton_fused(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ self.mrope_section,
+ is_interleaved=True, # Qwen3 uses interleaved mrope
+ partial_rotary_factor=self.partial_rotary_factor,
+ )
+ return q, cache_kv
diff --git a/lightllm/models/qwen3_5/layer_weights/__init__.py b/lightllm/models/qwen3_5/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..da93133444
--- /dev/null
+++ b/lightllm/models/qwen3_5/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,92 @@
+import torch
+
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight
+from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import (
+ Qwen3NextTransformerLayerWeight,
+)
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class Qwen35TransformerLayerWeight(Qwen3NextTransformerLayerWeight):
+ def _init_weight_names(self):
+ super()._init_weight_names()
+ self._gate_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"
+ self._gate_bias_name = None
+ self._up_weight_name = f"model.layers.{self.layer_num_}.mlp.up_proj.weight"
+ self._up_bias_name = None
+ self._gate_up_weight_name = f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"
+ self._gate_up_bias_name = None
+ self._down_weight_name = f"model.layers.{self.layer_num_}.mlp.down_proj.weight"
+ self._down_bias_name = None
+
+ def _init_gdn_weight(self):
+ # Initialize everything from parent first, then override only linear_in_proj.
+ super()._init_gdn_weight()
+
+ prefix = f"model.layers.{self.layer_num_}.linear_attn"
+ hidden_size = self.network_config_["hidden_size"]
+ qk_dim = self.linear_num_k_heads * self.linear_k_head_dim
+ v_dim = self.linear_num_v_heads * self.linear_v_head_dim
+
+ # NOTE: keep grouped layout directly (q, k, v, z, b, a).
+ self.linear_in_proj = ROWMMWeight(
+ in_dim=hidden_size,
+ out_dims=[
+ qk_dim,
+ qk_dim,
+ v_dim,
+ v_dim,
+ self.linear_num_v_heads,
+ self.linear_num_v_heads,
+ ],
+ weight_names=[
+ f"{prefix}.in_proj_q.weight",
+ f"{prefix}.in_proj_k.weight",
+ f"{prefix}.in_proj_v.weight",
+ f"{prefix}.in_proj_z.weight",
+ f"{prefix}.in_proj_b.weight",
+ f"{prefix}.in_proj_a.weight",
+ ],
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("in_proj_weight"),
+ )
+
+ def _preprocess_weight(self, weights):
+ # Keep parent conv1d preprocessing path.
+ linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight"
+ linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias"
+
+ if linear_conv1d_weight_name in weights:
+ weights[linear_conv1d_weight_name] = self._parse_linear_conv1d(
+ weights[linear_conv1d_weight_name].squeeze(1)
+ )
+ if linear_conv1d_bias_name in weights:
+ weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name])
+
+ self._split_linear_in_proj_qkv(weights)
+
+ def _split_linear_in_proj_qkv(self, weights):
+ prefix = f"model.layers.{self.layer_num_}.linear_attn"
+ qkv_name = f"{prefix}.in_proj_qkv.weight"
+ if qkv_name not in weights:
+ return
+
+ qk_dim = self.linear_num_k_heads * self.linear_k_head_dim
+ v_dim = self.linear_num_v_heads * self.linear_v_head_dim
+ expected_rows = 2 * qk_dim + v_dim
+
+ qkv = weights[qkv_name]
+ if qkv.shape[0] != expected_rows:
+ logger.warning(
+ f"Layer {self.layer_num_}: unexpected in_proj_qkv shape "
+ f"{tuple(qkv.shape)}, expected first dim {expected_rows}; skip split"
+ )
+ return
+
+ q, k, v = torch.split(qkv, [qk_dim, qk_dim, v_dim], dim=0)
+ weights[f"{prefix}.in_proj_q.weight"] = q
+ weights[f"{prefix}.in_proj_k.weight"] = k
+ weights[f"{prefix}.in_proj_v.weight"] = v
+ del weights[qkv_name]
diff --git a/lightllm/models/qwen3_5/model.py b/lightllm/models/qwen3_5/model.py
new file mode 100644
index 0000000000..63503c77ba
--- /dev/null
+++ b/lightllm/models/qwen3_5/model.py
@@ -0,0 +1,100 @@
+import os
+import json
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen3next.model import Qwen3NextTpPartModel
+from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import (
+ Qwen35TransformerLayerWeight,
+)
+from lightllm.models.qwen3_vl.model import QWen3VLTokenizer
+from lightllm.models.qwen3_vl.layer_infer.pre_layer_infer import (
+ Qwen3VLMultimodalPreLayerInfer,
+)
+from lightllm.models.qwen3_vl.layer_weights.pre_and_post_layer_weight import (
+ Qwen3VLPreAndPostLayerWeight,
+)
+from lightllm.models.qwen3_5.layer_infer.transformer_layer_infer import (
+ Qwen35TransformerLayerInfer,
+)
+from lightllm.models.qwen3_5.infer_struct import Qwen35InferStateInfo
+from lightllm.common.build_utils import repair_config
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class QWen3_5Tokenizer(QWen3VLTokenizer):
+ """
+ Tokenizer for Qwen3.5 multimodal model.
+
+ Inherits all multimodal tokenization logic from Qwen3VL,
+ including image and video token handling.
+ """
+
+ def __init__(self, tokenizer=None, image_processor=None, **kwargs):
+ super().__init__(tokenizer, image_processor, **kwargs)
+
+
+@ModelRegistry(["qwen3_5"], is_multimodal=True)
+class Qwen3_5TpPartModel(Qwen3NextTpPartModel):
+ """
+ Qwen3.5 Multimodal Model (Dense Variant)
+
+ This model combines:
+ - Hybrid attention from Qwen3Next (Gated Delta Networks + Full Attention)
+ - Multimodal capabilities from Qwen3VL (image/video processing)
+ - Dense MLP layers (non-MoE)
+
+ Architecture:
+ - Every Nth layer uses full attention (config: full_attention_interval)
+ - Other layers use linear attention (Gated Delta Networks)
+ - Vision encoder processes images/videos before text model
+ - Multimodal embeddings merged with text embeddings
+ """
+
+ transformer_weight_class = Qwen35TransformerLayerWeight
+ pre_and_post_weight_class = Qwen3VLPreAndPostLayerWeight
+
+ pre_layer_infer_class = Qwen3VLMultimodalPreLayerInfer
+ transformer_layer_infer_class = Qwen35TransformerLayerInfer
+
+ infer_state_class = Qwen35InferStateInfo
+
+ def _init_config(self):
+ config_path = os.path.join(self.weight_dir_, "config.json")
+
+ with open(config_path, "r") as json_file:
+ all_config = json.load(json_file)
+
+ self.config = all_config["text_config"]
+ self.vision_config = all_config.get("vision_config", None)
+
+ if self.vision_config is None:
+ logger.warning("No vision_config found in checkpoint. " "Multimodal features may not work correctly.")
+
+ # Apply standard config repairs
+ repair_config(self.config, same_names=["num_attention_heads", "n_head"])
+ repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"])
+ repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
+
+ # Qwen3.5 stores RoPE config under text_config.rope_parameters.
+ rope_parameters = self.config.get("rope_parameters")
+ if isinstance(rope_parameters, dict):
+ if "rope_theta" in rope_parameters and "rope_theta" not in self.config:
+ self.config["rope_theta"] = rope_parameters["rope_theta"]
+ if "partial_rotary_factor" in rope_parameters and "partial_rotary_factor" not in self.config:
+ self.config["partial_rotary_factor"] = rope_parameters["partial_rotary_factor"]
+ # Preserve the richer RoPE metadata in the expected field when absent.
+ if "rope_scaling" not in self.config:
+ self.config["rope_scaling"] = rope_parameters
+
+        # MoE routing default set on the dense model so the Qwen3_5MOETpPartModel subclass inherits it
+ if "norm_topk_prob" not in self.config:
+ self.config["norm_topk_prob"] = True # Standard default for MoE models
+
+ # Handle fine-tuning config if present
+ if self.finetune_config:
+ self.config["vocab_size"] = self.finetune_config.vocab_size
+
+ # Calculate num_kv_heads for KV cache memory management
+ # Required by parent class _init_mem_manager() in Qwen3NextTpPartModel
+ self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
diff --git a/lightllm/models/qwen3_5_moe/__init__.py b/lightllm/models/qwen3_5_moe/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_5_moe/layer_weights/__init__.py b/lightllm/models/qwen3_5_moe/layer_weights/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..fe4b1883bd
--- /dev/null
+++ b/lightllm/models/qwen3_5_moe/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,40 @@
+import torch
+from lightllm.models.qwen3_5.layer_weights.transformer_layer_weight import Qwen35TransformerLayerWeight
+
+
+class Qwen35MOETransformerLayerWeight(Qwen35TransformerLayerWeight):
+ def load_hf_weights(self, weights):
+ moe_intermediate_size = self.network_config_["moe_intermediate_size"]
+ split_fused_expert_weights(weights, self.layer_num_, moe_intermediate_size)
+ return super().load_hf_weights(weights)
+
+
+def split_fused_expert_weights(weights: dict, layer_num: int, moe_intermediate_size: int):
+ layer_prefix = f"model.layers.{layer_num}."
+ keys = list(weights.keys())
+ num_experts = 0
+
+ for k in keys:
+ if not k.startswith(layer_prefix):
+ continue
+
+ if "mlp.experts.gate_up_proj" in k:
+ fused_weight = weights.pop(k) # [num_experts, 2*inter_size, hidden_size]
+ num_experts = fused_weight.shape[0]
+
+ prefix = k.rsplit(".gate_up_proj", 1)[0]
+ gate_weight = fused_weight[:, :moe_intermediate_size, :]
+ up_weight = fused_weight[:, moe_intermediate_size:, :]
+
+ for expert_idx in range(num_experts):
+ weights[f"{prefix}.{expert_idx}.gate_proj.weight"] = gate_weight[expert_idx]
+ weights[f"{prefix}.{expert_idx}.up_proj.weight"] = up_weight[expert_idx]
+
+ elif "mlp.experts.down_proj" in k:
+ down_weight = weights.pop(k) # [num_experts, hidden_size, inter_size]
+ num_experts = down_weight.shape[0]
+
+ prefix = k.rsplit(".down_proj", 1)[0]
+
+ for expert_idx in range(num_experts):
+ weights[f"{prefix}.{expert_idx}.down_proj.weight"] = down_weight[expert_idx]
diff --git a/lightllm/models/qwen3_5_moe/model.py b/lightllm/models/qwen3_5_moe/model.py
new file mode 100644
index 0000000000..973274774f
--- /dev/null
+++ b/lightllm/models/qwen3_5_moe/model.py
@@ -0,0 +1,12 @@
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen3_5.model import Qwen3_5TpPartModel
+from lightllm.utils.log_utils import init_logger
+from lightllm.models.qwen3_5_moe.layer_weights.transformer_layer_weight import (
+ Qwen35MOETransformerLayerWeight,
+)
+
+
+@ModelRegistry("qwen3_5_moe", is_multimodal=True)
+class Qwen3_5MOETpPartModel(Qwen3_5TpPartModel):
+
+ transformer_weight_class = Qwen35MOETransformerLayerWeight
diff --git a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
index b85216f22c..721893a4cd 100644
--- a/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/qwen3_moe/layer_infer/transformer_layer_infer.py
@@ -22,14 +22,14 @@
class Qwen3MOETransformerLayerInfer(LlamaTransformerLayerInfer):
def __init__(self, layer_num, network_config):
- self.n_routed_experts = network_config["num_experts"]
+ self.n_routed_experts = network_config.get("num_experts", 0)
self.is_moe = (
- network_config["num_experts"] > 0
- and layer_num not in network_config["mlp_only_layers"]
- and (layer_num + 1) % network_config["decoder_sparse_step"] == 0
+ network_config.get("num_experts", 0) > 0
+ and layer_num not in network_config.get("mlp_only_layers", [])
+ and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0
)
- self.num_experts_per_tok = network_config["num_experts_per_tok"]
- self.norm_topk_prob = network_config["norm_topk_prob"]
+ self.num_experts_per_tok = network_config.get("num_experts_per_tok", 0)
+ self.norm_topk_prob = network_config.get("norm_topk_prob", True)
super().__init__(layer_num, network_config)
self.head_dim_ = network_config["head_dim"]
self.tp_k_head_num_ = max(self.tp_k_head_num_, 1)
diff --git a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
index 13ba6cbe0f..e525cb2d20 100644
--- a/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py
@@ -5,11 +5,11 @@
class Qwen3MOETransformerLayerWeight(Qwen3TransformerLayerWeight):
def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
- self.n_routed_experts = network_config["num_experts"]
+ self.n_routed_experts = network_config.get("num_experts", 0)
self.is_moe = (
- network_config["num_experts"] > 0
- and layer_num not in network_config["mlp_only_layers"]
- and (layer_num + 1) % network_config["decoder_sparse_step"] == 0
+ network_config.get("num_experts", 0) > 0
+ and layer_num not in network_config.get("mlp_only_layers", [])
+ and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0
)
super().__init__(layer_num, data_type, network_config, quant_cfg)
return
diff --git a/lightllm/models/qwen3_moe/model.py b/lightllm/models/qwen3_moe/model.py
index 10a5051276..b71d7f4878 100644
--- a/lightllm/models/qwen3_moe/model.py
+++ b/lightllm/models/qwen3_moe/model.py
@@ -25,4 +25,6 @@ def __init__(self, kvargs):
def _init_custom(self):
super()._init_custom()
- dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
+ # Only initialize DeepEP group for MoE models with num_experts
+ if "num_experts" in self.config and self.config["num_experts"] > 0:
+ dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])
diff --git a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
index c20c227996..0276724749 100644
--- a/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
+++ b/lightllm/models/qwen3_omni_moe_thinker/qwen3_omni_visual.py
@@ -60,6 +60,8 @@ def __init__(
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+ # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup)
+ self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
diff --git a/lightllm/models/qwen3_vl/qwen3_visual.py b/lightllm/models/qwen3_vl/qwen3_visual.py
index d389c853d5..bed8898115 100644
--- a/lightllm/models/qwen3_vl/qwen3_visual.py
+++ b/lightllm/models/qwen3_vl/qwen3_visual.py
@@ -29,6 +29,9 @@
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.models.qwen2_vl.vision_process import resize_image, Qwen2VLImageProcessor
from lightllm.models.qwen2_vl.qwen2_visual import VisionRotaryEmbedding, VisionFlashAttention
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
class Qwen3VLVisionMLP(nn.Module):
@@ -60,6 +63,8 @@ def __init__(
kernel_size = [temporal_patch_size, patch_size, patch_size]
self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+ # Convert weight to channels_last_3d for cuDNN optimization (~10% extra speedup)
+ self.proj.weight.data = self.proj.weight.data.contiguous(memory_format=torch.channels_last_3d)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
target_dtype = self.proj.weight.dtype
@@ -377,6 +382,7 @@ def encode(self, images: List[ImageItem]):
image_data = read_shm(get_shm_name_data(img.uuid))
image_data = Image.open(BytesIO(image_data))
pixel_values, image_grid_thw = self.processor.preprocess(image_data)
+
img_tensors.append(pixel_values)
img_grids.append(image_grid_thw)
else:
diff --git a/lightllm/models/qwen3next/__init__.py b/lightllm/models/qwen3next/__init__.py
new file mode 100644
index 0000000000..a9d22c6643
--- /dev/null
+++ b/lightllm/models/qwen3next/__init__.py
@@ -0,0 +1,3 @@
+from lightllm.models.qwen3next.model import Qwen3NextTpPartModel
+
+__all__ = ["Qwen3NextTpPartModel"]
diff --git a/lightllm/models/qwen3next/infer_struct.py b/lightllm/models/qwen3next/infer_struct.py
new file mode 100644
index 0000000000..cd7c8d908d
--- /dev/null
+++ b/lightllm/models/qwen3next/infer_struct.py
@@ -0,0 +1,16 @@
+import torch
+from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.utils.envs_utils import get_env_start_args
+
+
+class Qwen3NextInferStateInfo(LlamaInferStateInfo):
+ def __init__(self):
+ super().__init__()
+ self.gate_value = None
+
+ def init_some_extra_state(self, model):
+ super().init_some_extra_state(model)
+ self.b_att_seq_len = self.b_seq_len
+ self.b_buffer_idx = model.req_manager.req_to_buffer_index[self.b_req_idx, 0].contiguous()
+
+ return
diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
new file mode 100644
index 0000000000..ec07b38c5a
--- /dev/null
+++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,381 @@
+import os
+import torch
+
+import torch.distributed as dist
+from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import (
+ Qwen3NextTransformerLayerWeight,
+)
+from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
+from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo
+from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl
+from lightllm.utils.log_utils import init_logger
+from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager
+from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
+from typing import Tuple
+from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward
+from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update
+from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating
+from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule
+from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule
+from lightllm.distributed import all_reduce
+from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
+from lightllm.utils.envs_utils import get_env_start_args, get_llm_data_type
+from functools import partial
+
+logger = init_logger(__name__)
+
+
+class Qwen3NextTransformerLayerInfer(LlamaTransformerLayerInfer):
+ def __init__(self, layer_num, network_config):
+ self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0)
+ self.n_routed_experts = network_config.get("num_experts", 0)
+ self.is_moe = (
+ network_config.get("num_experts", 0) > 0
+ and layer_num not in network_config.get("mlp_only_layers", [])
+ and (layer_num + 1) % network_config.get("decoder_sparse_step", 1) == 0
+ )
+ self.num_experts_per_tok = network_config.get("num_experts_per_tok", 1)
+ self.norm_topk_prob = network_config.get("norm_topk_prob", False)
+
+ super().__init__(layer_num, network_config)
+ self.head_dim_ = network_config.get(
+ "head_dim", network_config["hidden_size"] // network_config["num_attention_heads"]
+ )
+ num_full_attention_layers = network_config["full_attention_interval"]
+ self.is_linear_attention_layer = (layer_num + 1) % num_full_attention_layers != 0
+ if self.is_linear_attention_layer:
+ self._init_linear_layer_metadata(layer_num, network_config)
+ return
+
+ def _init_linear_layer_metadata(self, layer_num, network_config):
+
+ # Linear attention specific dimensions
+ self.num_v_heads = network_config["linear_num_value_heads"]
+ self.num_k_heads = network_config["linear_num_key_heads"]
+ self.head_k_dim = network_config["linear_key_head_dim"]
+ self.head_v_dim = network_config["linear_value_head_dim"]
+ self.key_dim = self.head_k_dim * self.num_k_heads
+ self.value_dim = self.head_v_dim * self.num_v_heads
+ self.conv_kernel_dim = network_config["linear_conv_kernel_dim"]
+ self.activation = network_config["hidden_act"]
+
+ # Tensor parallelism dimensions
+ self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_
+ self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_
+ self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_
+ self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_
+ self.tp_key_dim = self.key_dim // self.tp_world_size_
+ self.tp_value_dim = self.value_dim // self.tp_world_size_
+
+ assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads"
+ self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads
+
+ # SSM state dtype optimization
+ ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32}
+ start_args = get_env_start_args()
+ self.ssm_state_dtype = ssm_dtype_dict.get(start_args.mamba_ssm_data_type, torch.bfloat16)
+
+ # Pre-compute whether dtype conversion is needed
+ # GDN kernel output dtype is self.data_type
+ # Conversion needed only if SSM state uses different dtype
+ self.needs_ssm_dtype_conversion = get_llm_data_type() != self.ssm_state_dtype
+ return
+
+ def _bind_func(self):
+ super()._bind_func()
+ self._bind_ffn()
+ return
+
+ def _bind_ffn(self):
+ if self.is_moe:
+ moe_mode = os.environ.get("MOE_MODE", "TP")
+ if moe_mode == "EP":
+ self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn_edp, self)
+ else:
+ self._ffn = partial(Qwen3NextTransformerLayerInfer._moe_ffn, self)
+ else:
+ self._ffn = partial(Qwen3NextTransformerLayerInfer._ffn, self)
+ return
+
+ def _compute_shared_expert(
+ self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+ ):
+ input = input.view(-1, self.embed_dim_)
+ shared_expert_out = super()._ffn(input, infer_state, layer_weight)
+ gate = layer_weight.ffn_gate.mm(input).sigmoid_()
+ shared_expert_out.mul_(gate)
+ return shared_expert_out
+
+ def _moe_ffn(
+ self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+ ):
+
+ shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight)
+
+ hidden_states = input.view(-1, self.embed_dim_)
+ num_tokens, hidden_dim = hidden_states.shape
+ router_logits = layer_weight.moe_gate.mm(hidden_states)
+ layer_weight.experts.experts(
+ hidden_states,
+ router_logits=router_logits,
+ top_k=self.num_experts_per_tok,
+ renormalize=self.norm_topk_prob,
+ use_grouped_topk=False,
+ topk_group=None,
+ num_expert_group=None,
+ )
+ hidden_states = hidden_states.view(num_tokens, hidden_dim)
+ hidden_states.add_(shared_expert_out)
+ return hidden_states
+
+ def _moe_ffn_edp(
+ self, input: torch.Tensor, infer_state: Qwen3NextInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight
+ ):
+ shared_expert_out = self._compute_shared_expert(input, infer_state, layer_weight)
+ hidden_states = input
+ token_num, hidden_dim = hidden_states.shape
+ router_logits = layer_weight.moe_gate.mm(hidden_states)
+ ep_output = layer_weight.experts.experts(
+ hidden_states,
+ router_logits=router_logits,
+ top_k=self.num_experts_per_tok,
+ renormalize=self.norm_topk_prob,
+ use_grouped_topk=False,
+ topk_group=None,
+ num_expert_group=None,
+ is_prefill=infer_state.is_prefill,
+ )
+ ep_output = ep_output.view(token_num, hidden_dim)
+ ep_output.add_(shared_expert_out)
+ return ep_output
+
+ def _get_qkv(
+ self,
+ input: torch.Tensor,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ input = input.view(-1, self.embed_dim_)
+ qkv_out = layer_weight.qkv_proj.mm(input)
+ q, cache_kv = qkv_out.split(
+ [self.tp_q_head_num_ * self.head_dim_ * 2, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_],
+ dim=-1,
+ )
+ o_gate = layer_weight._o_gate_proj.mm(input)
+ # In-place sigmoid saves one allocation (gate_value is consumed once in _get_o)
+ infer_state.gate_value = o_gate.sigmoid_()
+ layer_weight.qk_norm_weight_(
+ q,
+ cache_kv[:, : self.tp_k_head_num_ * self.head_dim_],
+ eps=self.eps_,
+ )
+ cache_kv = cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+ rotary_emb_fwd(
+ q.view(-1, self.tp_q_head_num_, self.head_dim_),
+ cache_kv[:, : self.tp_k_head_num_, :],
+ infer_state.position_cos,
+ infer_state.position_sin,
+ partial_rotary_factor=self.partial_rotary_factor,
+ )
+ return q, cache_kv
+
+ def _get_o(
+ self,
+ input,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ ) -> torch.Tensor:
+ """Output projection with gating (in-place multiply to save one allocation)."""
+ input = input.view(-1, self.tp_o_head_num_ * self.head_dim_)
+ input.mul_(infer_state.gate_value)
+ infer_state.gate_value = None
+ o_tensor = layer_weight.o_proj.mm(input)
+ return o_tensor
+
+ # ==================== GDN Helper Methods ====================
+
+ def context_attention_forward(
+ self,
+ input_embdings,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ ):
+ if not self.is_linear_attention_layer:
+ return super().context_attention_forward(input_embdings, infer_state, layer_weight)
+
+ gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=True)
+ if self.tp_world_size_ > 1:
+ all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ return gdn_out
+
+ def token_attention_forward(
+ self,
+ input_embdings,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ ):
+ if not self.is_linear_attention_layer:
+ return super().token_attention_forward(input_embdings, infer_state, layer_weight)
+ gdn_out = self.gdn_forward(input_embdings, infer_state, layer_weight, is_prefill=False)
+ if self.tp_world_size_ > 1:
+ all_reduce(gdn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
+ return gdn_out
+
+ def gdn_forward(
+ self,
+ input: torch.Tensor,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ is_prefill: bool,
+ ):
+ assert isinstance(infer_state.mem_manager, Qwen3NextHybridMemManager)
+
+ input = input.view(-1, self.embed_dim_)
+ conv_states, ssm_states = infer_state.mem_manager.get_mamba_cache(self.layer_num_)
+
+ mixed_qkvzba = layer_weight.linear_in_proj.mm(input)
+ mixed_qkv, z, b, a = self._split_qkvzba(mixed_qkvzba, is_decode=not is_prefill)
+
+ if is_prefill:
+ g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight)
+ core_attn_out = self._gdn_prefill_kernel(
+ mixed_qkv, conv_states, ssm_states, g, beta, infer_state, layer_weight
+ )
+ else:
+ core_attn_out = self._gdn_decode_kernel(mixed_qkv, conv_states, ssm_states, a, b, infer_state, layer_weight)
+
+ num_tokens = z.shape[0]
+ core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1])
+ z = z.reshape(-1, z.shape[-1])
+ norm_out = self.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device)
+ gated_rmsnorm_forward(
+ core_attn_out,
+ layer_weight.linear_norm.weight,
+ None,
+ self.eps_,
+ z,
+ out=norm_out,
+ )
+ core_attn_out = norm_out.view(num_tokens, -1)
+ output = layer_weight.linear_out_proj.mm(core_attn_out)
+ return output
+
+ def _split_qkvzba(self, mixed_qkvzba, is_decode=False):
+ qkv_dim = self.tp_key_dim * 2 + self.tp_value_dim
+ z_end = qkv_dim + self.tp_value_dim
+ b_end = z_end + self.tp_num_v_heads
+ mixed_qkv = mixed_qkvzba[:, :qkv_dim]
+ z = mixed_qkvzba[:, qkv_dim:z_end].view(-1, self.tp_num_v_heads, self.head_v_dim)
+ b = mixed_qkvzba[:, z_end:b_end]
+ a = mixed_qkvzba[:, b_end:]
+ return mixed_qkv, z, b, a
+
+ def _rearrange_mixed_qkv(self, mixed_qkv, decode=False):
+ if mixed_qkv is None:
+ return None, None, None
+ if decode:
+ query, key, value = torch.split(
+ mixed_qkv,
+ [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim],
+ dim=-1,
+ )
+ batch_size = mixed_qkv.shape[0]
+ query = query.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim)
+ key = key.view(batch_size, 1, self.tp_num_k_heads, self.head_k_dim)
+ value = value.view(batch_size, 1, self.tp_num_v_heads, self.head_v_dim)
+ return query, key, value
+ else:
+ query, key, value = torch.split(
+ mixed_qkv,
+ [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim],
+ dim=-1,
+ )
+ seq_len = query.shape[0]
+ query = query.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim)
+ key = key.view(1, seq_len, self.tp_num_k_heads, self.head_k_dim)
+ value = value.view(1, seq_len, self.tp_num_v_heads, self.head_v_dim)
+ return query, key, value
+
+ def _gdn_prefill_kernel(
+ self,
+ mixed_qkv: torch.Tensor,
+ conv_states: torch.Tensor,
+ ssm_states: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ ):
+ """Prefill kernel for GDN forward pass."""
+ # Conv1D processing
+ mixed_qkv = mixed_qkv.transpose(0, 1)
+ out_tensor = causal_conv1d_fn(
+ mixed_qkv,
+ layer_weight.linear_conv1d.mm_param.weight,
+ bias=layer_weight.linear_conv1d.bias,
+ query_start_loc=infer_state.b1_cu_q_seq_len,
+ cache_indices=infer_state.b_buffer_idx,
+ has_initial_state=infer_state.b_ready_cache_len > 0,
+ conv_states=conv_states,
+ activation=self.activation,
+ )
+ mixed_qkv = out_tensor.transpose(0, 1)
+
+ # Recurrent processing
+ query, key, value = self._rearrange_mixed_qkv(mixed_qkv)
+ initial_state = ssm_states[infer_state.b_buffer_idx]
+ # g and beta have shape (total_tokens, num_heads), need to unsqueeze to get (1, total_tokens, num_heads)
+ core_attn_out, last_recurrent_state = chunk_gated_delta_rule(
+ q=query,
+ k=key,
+ v=value,
+ g=g.unsqueeze(0),
+ beta=beta.unsqueeze(0),
+ initial_state=initial_state,
+ output_final_state=True,
+ cu_seqlens=infer_state.b1_cu_q_seq_len,
+ head_first=False,
+ use_qk_l2norm_in_kernel=True,
+ )
+ if self.needs_ssm_dtype_conversion:
+ ssm_states[infer_state.b_buffer_idx] = last_recurrent_state.to(self.ssm_state_dtype, copy=False)
+ else:
+ ssm_states[infer_state.b_buffer_idx] = last_recurrent_state
+ return core_attn_out
+
+ def _gdn_decode_kernel(
+ self,
+ mixed_qkv: torch.Tensor,
+ conv_states: torch.Tensor,
+ ssm_states: torch.Tensor,
+ a: torch.Tensor,
+ b: torch.Tensor,
+ infer_state: Qwen3NextInferStateInfo,
+ layer_weight: Qwen3NextTransformerLayerWeight,
+ ):
+ mixed_qkv = causal_conv1d_update(
+ mixed_qkv,
+ conv_states,
+ layer_weight.linear_conv1d.mm_param.weight,
+ bias=layer_weight.linear_conv1d.bias,
+ activation=self.activation,
+ conv_state_indices=infer_state.b_buffer_idx,
+ )
+
+ # Recurrent processing with fused gating
+ # FusedRecurrentFunction.forward calls .contiguous() on q/k/v/a/b internally
+ query, key, value = self._rearrange_mixed_qkv(mixed_qkv, decode=True)
+ core_attn_out, _ = fused_recurrent_gated_delta_rule(
+ q=query,
+ k=key,
+ v=value,
+ initial_state=ssm_states,
+ inplace_final_state=True,
+ ssm_state_indices=infer_state.b_buffer_idx,
+ use_qk_l2norm_in_kernel=True,
+ A_log=layer_weight.linear_A_log.weight,
+ dt_bias=layer_weight.linear_dt_bias.weight,
+ a_raw=a,
+ b_raw=b,
+ )
+ return core_attn_out
diff --git a/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py
new file mode 100644
index 0000000000..daaf146907
--- /dev/null
+++ b/lightllm/models/qwen3next/layer_weights/pre_and_post_layer_weight.py
@@ -0,0 +1,29 @@
+from lightllm.common.basemodel import PreAndPostLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import EmbeddingWeight, LMHeadWeight, GEMMANormWeight
+
+
+class Qwen3NextPreAndPostLayerWeight(PreAndPostLayerWeight):
+ def __init__(self, data_type, network_config):
+ super().__init__(data_type, network_config)
+ hidden_size = network_config["hidden_size"]
+ vocab_size = network_config["vocab_size"]
+ self.wte_weight_ = EmbeddingWeight(
+ dim=hidden_size,
+ vocab_size=vocab_size,
+ weight_name="model.embed_tokens.weight",
+ data_type=self.data_type_,
+ )
+ tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
+ self.lm_head_weight_ = LMHeadWeight(
+ dim=hidden_size,
+ vocab_size=vocab_size,
+ weight_name="lm_head.weight",
+ data_type=self.data_type_,
+ embedding_weight=self.wte_weight_ if tie_word_embeddings else None,
+ )
+ self.final_norm_weight_ = GEMMANormWeight(
+ dim=hidden_size,
+ weight_name="model.norm.weight",
+ data_type=self.data_type_,
+ )
+ return
diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
new file mode 100644
index 0000000000..31dae85ec8
--- /dev/null
+++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,273 @@
+import torch
+from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+ ROWMMWeight,
+ COLMMWeight,
+ RMSNormWeight,
+ GEMMANormWeight,
+ TpParameterWeight,
+ QKVROWNMMWeight,
+ QKGEMMANormWeight,
+)
+
+
+class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight):
+ def __init__(self, layer_num, data_type, network_config, quant_cfg=None):
+ num_full_attention_layers = network_config["full_attention_interval"]
+ self.is_linear_attention_layer = (layer_num + 1) % num_full_attention_layers != 0
+ super().__init__(layer_num, data_type, network_config, quant_cfg)
+ return
+
+ def _init_qkv(self):
+ in_dim = self.n_embed
+ q_out_dim = self.q_head_num_ * self.head_dim
+ self.qkv_proj = QKVROWNMMWeight(
+ in_dim=in_dim,
+ q_head_num=self.q_head_num_,
+ kv_head_num=self.k_head_num_,
+ head_dim=self.head_dim,
+ weight_names=[self._q_weight_name, self._k_weight_name, self._v_weight_name],
+ data_type=self.data_type_,
+ bias_names=[self._q_bias_name, self._k_bias_name, self._v_bias_name],
+ quant_method=self.get_quant_method("qkv_proj"),
+ )
+ self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight"
+ self._o_gate_proj = ROWMMWeight(
+ in_dim=in_dim,
+ out_dims=[q_out_dim],
+ weight_names=[self._o_gate_weight_name],
+ data_type=self.data_type_,
+ bias_names=None,
+ quant_method=self.get_quant_method("o_gate_proj"),
+ )
+
+ def _init_weight(self):
+ if self.is_linear_attention_layer:
+ self._init_gdn_weight()
+ else:
+ self._init_qkv()
+ self._init_o()
+
+ if self.is_moe:
+ self._init_moe()
+ else:
+ self._init_ffn()
+ self._init_norm()
+
+ def _init_moe(self):
+ super()._init_moe()
+ self._init_gated_ffn()
+ return
+
+ def _init_norm(self):
+ hidden_size = self.network_config_["hidden_size"]
+ self.att_norm_weight_ = GEMMANormWeight(
+ dim=hidden_size,
+ weight_name=self._att_norm_weight_name,
+ data_type=self.data_type_,
+ )
+ self.ffn_norm_weight_ = GEMMANormWeight(
+ dim=hidden_size,
+ weight_name=self._ffn_norm_weight_name,
+ data_type=self.data_type_,
+ )
+ if not self.is_linear_attention_layer:
+ self.qk_norm_weight_ = QKGEMMANormWeight(
+ dim=self.head_dim,
+ q_weight_name=self._q_norm_name,
+ k_weight_name=self._k_norm_name,
+ data_type=self.data_type_,
+ )
+
+ def _init_gated_ffn(self):
+ hidden_size = self.network_config_["hidden_size"]
+ if "shared_expert_intermediate_size" not in self.network_config_:
+ return
+ prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert"
+ inter_size = self.network_config_["shared_expert_intermediate_size"]
+ self.gate_up_proj = ROWMMWeight(
+ in_dim=hidden_size,
+ out_dims=[inter_size, inter_size],
+ weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"],
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("gate_up_proj"),
+ )
+ self.down_proj = COLMMWeight(
+ in_dim=inter_size,
+ out_dims=[hidden_size],
+ weight_names=f"{prefix}.down_proj.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("down_proj"),
+ )
+ self.ffn_gate = ROWMMWeight(
+ in_dim=hidden_size,
+ out_dims=[1],
+ weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight",
+ data_type=self.data_type_,
+ bias_names=None,
+ quant_method=None,
+ tp_rank=0,
+ tp_world_size=1,
+ )
+
+ def _split_q_with_gate(self, weights):
+ if self._q_weight_name in weights:
+ weight = weights[self._q_weight_name]
+ num_heads = self.q_head_num_
+ weight = weight.view(num_heads * 2, self.head_dim, -1)
+ _q_proj = weight[0::2].reshape(-1, weight.shape[-1])
+ _gate_proj = weight[1::2].reshape(-1, weight.shape[-1])
+ weights[self._q_weight_name] = _q_proj
+ weights[self._o_gate_weight_name] = _gate_proj
+
+ def _parse_config(self):
+ super()._parse_config()
+ self.linear_num_v_heads = self.network_config_["linear_num_value_heads"]
+ self.linear_num_k_heads = self.network_config_["linear_num_key_heads"]
+ self.linear_k_head_dim = self.network_config_["linear_key_head_dim"]
+ self.linear_v_head_dim = self.network_config_["linear_value_head_dim"]
+
+ def _init_gdn_weight(self):
+ prefix = f"model.layers.{self.layer_num_}.linear_attn"
+ hidden_size = self.network_config_["hidden_size"]
+ qk_dim = self.linear_num_k_heads * self.linear_k_head_dim
+ v_dim = self.linear_num_v_heads * self.linear_v_head_dim
+ conv1d_channels = qk_dim + qk_dim + v_dim # q + k + v concatenated
+ kernel_size = self.network_config_.get("linear_conv_kernel_dim", 4)
+
+ # Conv1d weight: after _preprocess_weight, shape is [channels, kernel_size].
+ self.linear_conv1d = ROWMMWeight(
+ in_dim=kernel_size,
+ out_dims=[conv1d_channels],
+ weight_names=f"{prefix}.conv1d.weight",
+ data_type=self.data_type_,
+ quant_method=None,
+ )
+
+ # in_proj_qkvz: q(qk_dim) + k(qk_dim) + v(v_dim) + z(v_dim)
+ # in_proj_ba: beta(num_v_heads) + alpha(num_v_heads) — per-head scalars
+ qkvz_dim = qk_dim + qk_dim + v_dim + v_dim
+ ba_dim = self.linear_num_v_heads + self.linear_num_v_heads
+ self.linear_in_proj = ROWMMWeight(
+ in_dim=hidden_size,
+ out_dims=[qkvz_dim, ba_dim],
+ weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"],
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("in_proj_weight"),
+ )
+
+ self.linear_out_proj = COLMMWeight(
+ in_dim=v_dim,
+ out_dims=[hidden_size],
+ weight_names=f"{prefix}.out_proj.weight",
+ data_type=self.data_type_,
+ quant_method=self.get_quant_method("out_proj_weight"),
+ )
+
+ split_n_embed = self.linear_num_v_heads // self.tp_world_size_
+ self.linear_dt_bias = TpParameterWeight(
+ weight_name=f"{prefix}.dt_bias",
+ data_type=torch.float32,
+ split_n_embed=split_n_embed,
+ bias_name=None,
+ weight_shape=(self.linear_num_v_heads,), # Full shape before TP split
+ bias_shape=None,
+ )
+
+ self.linear_A_log = TpParameterWeight(
+ weight_name=f"{prefix}.A_log",
+ data_type=torch.float32,
+ split_n_embed=split_n_embed,
+ bias_name=None,
+ weight_shape=(self.linear_num_v_heads,), # Full shape before TP split
+ bias_shape=None,
+ )
+
+ # Norm is applied per-head across head_dim, not across all heads
+ linear_norm_dim = self.linear_v_head_dim
+ self.linear_norm = RMSNormWeight(
+ dim=linear_norm_dim,
+ weight_name=f"{prefix}.norm.weight",
+ data_type=self.data_type_,
+ )
+
+ def _preprocess_weight(self, weights):
+ linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight"
+ linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias"
+ if linear_conv1d_weight_name in weights:
+ # squeeze [channels, 1, kernel] -> [channels, kernel], then rearrange for TP
+ # Result shape: [channels, kernel_size] — matches causal_conv1d_fn's (dim, width)
+ weights[linear_conv1d_weight_name] = self._parse_linear_conv1d(
+ weights[linear_conv1d_weight_name].squeeze(1)
+ )
+ if linear_conv1d_bias_name in weights:
+ weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name])
+ self._rearrange_gdn_in_proj_weights(weights)
+
+ def _rearrange_gdn_in_proj_weights(self, weights):
+ """Rearrange in_proj_qkvz and in_proj_ba weight rows from interleaved per-k-head layout
+ to TP-aware grouped layout so that after ROWMMWeight's row-slicing, each rank's
+ MM output is already [q_chunk, k_chunk, v_chunk, z_chunk, b_chunk, a_chunk].
+ """
+ num_k = self.linear_num_k_heads
+ k_dim = self.linear_k_head_dim
+ v_dim = self.linear_v_head_dim
+ num_v_per_k = self.linear_num_v_heads // num_k
+ tp = self.tp_world_size_
+
+ # Rearrange in_proj_qkvz
+ qkvz_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_qkvz.weight"
+ if qkvz_name in weights:
+ w = weights[qkvz_name]
+ hidden = w.shape[-1]
+ # Each k-head group: q(k_dim) + k(k_dim) + v(num_v_per_k * v_dim) + z(num_v_per_k * v_dim) rows
+ group_size = k_dim + k_dim + num_v_per_k * v_dim + num_v_per_k * v_dim
+ w = w.view(num_k, group_size, hidden)
+ v_block = num_v_per_k * v_dim
+ all_q = w[:, :k_dim, :].reshape(-1, hidden) # [total_q_dim, H]
+ all_k = w[:, k_dim : 2 * k_dim, :].reshape(-1, hidden) # [total_k_dim, H]
+ all_v = w[:, 2 * k_dim : 2 * k_dim + v_block, :].reshape(-1, hidden) # [total_v_dim, H]
+ all_z = w[:, 2 * k_dim + v_block :, :].reshape(-1, hidden) # [total_v_dim, H]
+ # Chunk each component by TP, interleave so row-slicing gives grouped layout per rank
+ q_chunks = all_q.chunk(tp, dim=0)
+ k_chunks = all_k.chunk(tp, dim=0)
+ v_chunks = all_v.chunk(tp, dim=0)
+ z_chunks = all_z.chunk(tp, dim=0)
+ weights[qkvz_name] = torch.cat(
+ [torch.cat([q_chunks[i], k_chunks[i], v_chunks[i], z_chunks[i]], dim=0) for i in range(tp)],
+ dim=0,
+ )
+
+ # Rearrange in_proj_ba
+ ba_name = f"model.layers.{self.layer_num_}.linear_attn.in_proj_ba.weight"
+ if ba_name in weights:
+ w = weights[ba_name]
+ hidden = w.shape[-1]
+ group_size = 2 * num_v_per_k
+ w = w.view(num_k, group_size, hidden)
+ all_b = w[:, :num_v_per_k, :].reshape(-1, hidden) # [total_num_v, H]
+ all_a = w[:, num_v_per_k:, :].reshape(-1, hidden) # [total_num_v, H]
+ b_chunks = all_b.chunk(tp, dim=0)
+ a_chunks = all_a.chunk(tp, dim=0)
+ weights[ba_name] = torch.cat(
+ [torch.cat([b_chunks[i], a_chunks[i]], dim=0) for i in range(tp)],
+ dim=0,
+ )
+
+ def _parse_linear_conv1d(self, weight):
+ qk_dim = self.linear_num_k_heads * self.linear_k_head_dim
+ v_dim = self.linear_num_v_heads * self.linear_v_head_dim
+ q, k, v = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0)
+ q_splits = q.chunk(self.tp_world_size_, dim=0)
+ k_splits = k.chunk(self.tp_world_size_, dim=0)
+ v_splits = v.chunk(self.tp_world_size_, dim=0)
+ new_weight = torch.cat(
+ [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0
+ )
+ return new_weight
+
+ def load_hf_weights(self, weights):
+ self._split_q_with_gate(weights)
+ if self.is_linear_attention_layer:
+ self._preprocess_weight(weights)
+ super().load_hf_weights(weights)
diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py
new file mode 100644
index 0000000000..5c8d486edf
--- /dev/null
+++ b/lightllm/models/qwen3next/mem_manager.py
@@ -0,0 +1,154 @@
+import torch
+from typing import Tuple
+from lightllm.utils.log_utils import init_logger
+from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager
+from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
+from lightllm.server.core.objs.start_args_type import StartArgs
+
+logger = init_logger(__name__)
+
+
+class Qwen3NextHybridMemManager(MemoryManager):
+ @staticmethod
+ def calculate_mamba_cache_size(
+ start_args: StartArgs,
+ max_total_token_num: int,
+ mem_fraction: float,
+ config: dict,
+ head_linear_k_dim: int,
+ num_linear_k_heads: int,
+ head_linear_v_dim: int,
+ num_linear_v_heads: int,
+ tp_world_size: int,
+ data_type: torch.dtype,
+ ) -> int:
+ """Calculate mamba cache size based on available memory and mamba_cache_ratio."""
+ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
+ import torch.distributed as dist
+
+ use_ratio = max_total_token_num is None and start_args.mamba_cache_size is None
+
+ world_size = dist.get_world_size()
+ total_memory = get_total_gpu_memory()
+ available_memory = get_available_gpu_memory(world_size) - total_memory * (1 - mem_fraction)
+
+ conv_kernel_size = config["linear_conv_kernel_dim"]
+ conv_dim = (
+ head_linear_k_dim * num_linear_k_heads * 2 + head_linear_v_dim * num_linear_v_heads
+ ) // tp_world_size
+
+ num_linear_layers = config["n_layer"] - (config["n_layer"] // config["full_attention_interval"])
+
+ conv_cell_size = num_linear_layers * conv_dim * (conv_kernel_size - 1) * torch._utils._element_size(data_type)
+
+ ssm_dtype = torch.bfloat16 if start_args.mamba_ssm_data_type == "bfloat16" else torch.float32
+ ssm_cell_size = (
+ num_linear_layers
+ * (num_linear_v_heads // tp_world_size)
+ * head_linear_k_dim
+ * head_linear_v_dim
+ * torch._utils._element_size(ssm_dtype)
+ )
+
+ total_cell_size = conv_cell_size + ssm_cell_size
+
+ if use_ratio:
+ # mamba_cache_ratio = mamba_memory / total_cache_memory
+ mamba_cache_ratio = start_args.mamba_cache_ratio if start_args.mamba_cache_ratio is not None else 0.5
+ mamba_memory_gb = available_memory * mamba_cache_ratio
+ else:
+ mamba_memory_gb = available_memory
+ mamba_cache_ratio = None
+
+ mamba_cache_size = int(mamba_memory_gb * 1024 ** 3 / total_cell_size)
+
+ if mamba_cache_size < start_args.running_max_req_size * 2:
+ ratio = mamba_cache_ratio if mamba_cache_ratio is not None else 0.5
+ raise ValueError(
+ f"Insufficient memory for mamba cache allocation!\n\n"
+ f"mamba_cache_size should be at least running_max_req_size * 2\n"
+ f"Calculated mamba_cache_size ({mamba_cache_size}) < "
+ f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n\n"
+ f"Memory budget:\n"
+ f" Available for mamba cache: {mamba_memory_gb:.2f} GB\n"
+ f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
+ f" Calculated buffers: {mamba_cache_size}\n"
+ f" Required buffers: {start_args.running_max_req_size}\n\n"
+ f"Solutions:\n"
+ f" 1. Reduce --running_max_req_size to {mamba_cache_size} or lower\n"
+ f" 2. Increase --mamba_cache_ratio from {ratio} to "
+ f"{start_args.running_max_req_size / mamba_cache_size * ratio:.3f} or higher\n"
+ f" 3. Increase --mem_fraction to leave more memory for caches\n"
+ )
+
+ logger.info(
+ f"Mamba cache allocation:\n"
+ f" Available memory: {mamba_memory_gb:.2f} GB\n"
+ f" Memory per buffer: {total_cell_size / 1024 ** 2:.2f} MB\n"
+ f" Calculated mamba_cache_size: {mamba_cache_size}"
+ )
+
+ return mamba_cache_size
+
+ def __init__(
+ self,
+ full_attn_cache_size,
+ linear_attn_cache_size,
+ dtype,
+ num_kv_heads,
+ head_dim,
+ layer_num,
+ mtp_layer_num,
+ full_attention_interval: int,
+ conv_state_dtype: torch.dtype,
+ conv_state_shape: Tuple[int, ...],
+ ssm_state_dtype: torch.dtype,
+ ssm_state_shape: Tuple[int, ...],
+ max_req_num: int,
+ always_copy=False,
+ mem_fraction=0.9,
+ ):
+
+ self.full_attention_interval = full_attention_interval
+ assert layer_num % full_attention_interval == 0
+ self.layer_num = layer_num
+ self.mtp_layer_num = mtp_layer_num
+ self.full_attn_layer_num = layer_num // full_attention_interval
+ self.linear_attn_layer_num = layer_num - self.full_attn_layer_num
+
+ self.mamba_cache_mem_manager = MambaCacheManager(
+ linear_attn_cache_size,
+ self.linear_attn_layer_num,
+ conv_state_dtype,
+ conv_state_shape,
+ ssm_state_dtype,
+ ssm_state_shape,
+ )
+
+ super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction)
+
+ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
+ # KV buffer layout: [None, None, None, kv_cache, None, None, None, kv_cache, ...,
+ # None, kv_cache, mtp_kv_cache, mtp_kv_cache]
+ # Only full attention layers and MTP layers have KV cache.
+ self.kv_buffer = [None for _ in range(self.layer_num)]
+ for layer_id in range(self.full_attn_layer_num):
+ self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty(
+ (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda"
+ )
+ for _ in range(self.mtp_layer_num):
+ self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda"))
+
+ def free_all(self):
+ super().free_all()
+ self.mamba_cache_mem_manager.free_all()
+ return
+
+ def get_cell_size(self):
+ # Only full attention layers and MTP layers have KV cache
+ kv_cache_layer_num = self.full_attn_layer_num + self.mtp_layer_num
+ return 2 * self.head_num * self.head_dim * kv_cache_layer_num * torch._utils._element_size(self.dtype)
+
+ def get_mamba_cache(self, layer_idx: int):
+ layer_idx_in_linear = layer_idx - (layer_idx // self.full_attention_interval)
+ return self.mamba_cache_mem_manager.get_mamba_cache(layer_idx_in_linear)
diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py
new file mode 100644
index 0000000000..b00f57f3ec
--- /dev/null
+++ b/lightllm/models/qwen3next/model.py
@@ -0,0 +1,139 @@
+import torch
+from typing import Optional
+import triton
+from lightllm.models.registry import ModelRegistry
+from lightllm.models.qwen3_moe.model import Qwen3MOEModel
+from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import (
+ Qwen3NextTransformerLayerWeight,
+)
+from lightllm.models.qwen3next.layer_weights.pre_and_post_layer_weight import Qwen3NextPreAndPostLayerWeight
+from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import (
+ Qwen3NextTransformerLayerInfer,
+)
+from lightllm.models.qwen3next.infer_struct import Qwen3NextInferStateInfo
+from lightllm.utils.log_utils import init_logger
+from lightllm.distributed.communication_op import dist_group_manager
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.models.qwen3next.mem_manager import Qwen3NextHybridMemManager
+from lightllm.server.core.objs.start_args_type import StartArgs
+from lightllm.common.req_manager import ReqManagerForMamba
+from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache
+
+logger = init_logger(__name__)
+
+
@ModelRegistry("qwen3_next")
class Qwen3NextTpPartModel(Qwen3MOEModel):
    """Tensor-parallel Qwen3-Next model.

    Qwen3-Next interleaves linear-attention (gated delta rule) layers with
    full-attention layers (every ``full_attention_interval``-th layer), so
    it wires in a hybrid memory manager (KV cache + conv/ssm mamba states),
    a mamba-aware request manager, and a hybrid radix cache.
    """

    # weight class
    pre_and_post_weight_class = Qwen3NextPreAndPostLayerWeight
    transformer_weight_class = Qwen3NextTransformerLayerWeight

    # infer class
    transformer_layer_infer_class = Qwen3NextTransformerLayerInfer

    # infer state class
    infer_state_class = Qwen3NextInferStateInfo

    def get_radix_class(self):
        # Prefix cache must track linear-attention state in addition to KV pages.
        return HybridRadixCache

    def __init__(self, kvargs) -> None:
        # Set before super().__init__ so _init_mem_manager (called by the base
        # constructor) can assign into it.
        self.mem_manager: Qwen3NextHybridMemManager = None

        def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor:
            return torch.empty(size, device="cuda", dtype=torch.int8)

        # Set Triton allocator for TMA descriptors
        # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py
        triton.set_allocator(_triton_allocator)
        logger.info("Triton allocator set for Qwen3Next model")
        super().__init__(kvargs)

    def autotune_layers(self):
        # NOTE(review): presumably one "period" of layers (linear layers plus
        # the closing full-attention layer) is representative for autotuning —
        # confirm against the autotuner's use of this value.
        return self.config["full_attention_interval"]

    def _init_config(self):
        super()._init_config()
        # KV heads are sharded across TP ranks, but never below 1 per rank.
        self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)

    def _init_custom(self):
        super()._init_custom()
        # Only initialize DeepEP group for MoE models with num_experts
        if "num_experts" in self.config and self.config["num_experts"] > 0:
            dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"])

    def _init_mem_manager(self):
        """Build the hybrid KV + mamba cache manager from config/start args."""
        assert self.config["num_attention_heads"] % self.tp_world_size_ == 0

        start_args: StartArgs = get_env_start_args()
        mamba_cache_size = start_args.mamba_cache_size

        # Linear-attention (gated delta rule) head geometry.
        self.num_linear_k_heads = self.config["linear_num_key_heads"]
        self.num_linear_v_heads = self.config["linear_num_value_heads"]
        self.head_linear_k_dim = self.config["linear_key_head_dim"]
        self.head_linear_v_dim = self.config["linear_value_head_dim"]

        if mamba_cache_size is None:
            # Derive the mamba pool size from the available memory budget.
            mamba_cache_size = Qwen3NextHybridMemManager.calculate_mamba_cache_size(
                start_args=start_args,
                max_total_token_num=self.max_total_token_num,
                mem_fraction=self.mem_fraction,
                config=self.config,
                head_linear_k_dim=self.head_linear_k_dim,
                num_linear_k_heads=self.num_linear_k_heads,
                head_linear_v_dim=self.head_linear_v_dim,
                num_linear_v_heads=self.num_linear_v_heads,
                tp_world_size=self.tp_world_size_,
                data_type=self.data_type,
            )
        else:
            # NOTE(review): the "* 2" floor presumably reserves one working +
            # one cached mamba slot per running request — confirm.
            if mamba_cache_size < start_args.running_max_req_size * 2:
                raise ValueError(
                    f"Explicitly set mamba_cache_size ({mamba_cache_size}) < "
                    f"running_max_req_size * 2 ({start_args.running_max_req_size * 2})\n"
                    f"Please increase mamba_cache_size to at least {start_args.running_max_req_size * 2}"
                )

        conv_kernel_size = self.config["linear_conv_kernel_dim"]
        # Conv state spans Q/K (x2) plus V projections along the channel dim.
        conv_dim = (
            self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads
        )

        ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32}
        if start_args.mamba_ssm_data_type not in ssm_dtype_dict:
            raise ValueError(
                f"Invalid mamba_ssm_data_type: {start_args.mamba_ssm_data_type}."
                f" Must be one of {list(ssm_dtype_dict.keys())}"
            )

        self.mem_manager = Qwen3NextHybridMemManager(
            full_attn_cache_size=self.max_total_token_num,
            linear_attn_cache_size=mamba_cache_size,
            dtype=self.data_type,
            num_kv_heads=self.num_kv_heads,
            head_dim=self.config["head_dim"],
            layer_num=self.config["n_layer"],
            mtp_layer_num=start_args.mtp_step,
            full_attention_interval=self.config["full_attention_interval"],
            conv_state_dtype=self.data_type,
            # Conv state keeps the last (kernel - 1) positions per channel.
            conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1),
            ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type],
            ssm_state_shape=(
                self.num_linear_v_heads // self.tp_world_size_,
                self.head_linear_k_dim,
                self.head_linear_v_dim,
            ),
            max_req_num=self.max_req_num,
            mem_fraction=self.mem_fraction,
        )

    def _init_req_manager(self):
        """Create the mamba-aware request manager sized for the longest sequence."""
        create_max_seq_len = 0

        # Either bound may be unset; take the larger of whichever exist.
        if self.batch_max_tokens is not None:
            create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens)
        if self.max_seq_length is not None:
            create_max_seq_len = max(create_max_seq_len, self.max_seq_length)

        self.req_manager = ReqManagerForMamba(self.max_req_num, create_max_seq_len, self.mem_manager)
diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py
new file mode 100644
index 0000000000..c6d099a2d8
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py
@@ -0,0 +1,122 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d.py
+
+from typing import Optional
+
+import torch
+
+from sgl_kernel import causal_conv1d_fwd
+from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel
+
+
def causal_conv1d_fn(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    query_start_loc: Optional[torch.Tensor] = None,
    cache_indices: Optional[torch.Tensor] = None,
    has_initial_state: Optional[torch.Tensor] = None,
    conv_states: Optional[torch.Tensor] = None,
    activation: Optional[str] = "silu",
    pad_slot_id: int = -1,
    **kwargs,
):
    """Run the causal conv1d forward kernel, writing the result into ``x``.

    Args:
        x: ``(batch, dim, seqlen)``, or ``(dim, cu_seq_len)`` for varlen input
            (sequences concatenated left to right).
        weight: ``(dim, width)``.
        bias: optional ``(dim,)``.
        query_start_loc: ``(batch + 1,)`` int32 cumulative start offsets for
            varlen input, prepended by 0 — e.g. ``[0, 10, 16, 17]`` for
            ``x.shape == (dim, 17)``.
        cache_indices: ``(batch,)`` int32; maps each sequence to its row in
            ``conv_states`` (``conv_state = conv_states[cache_indices[b]]``).
        has_initial_state: ``(batch,)`` bool; whether the kernel seeds each
            sequence from its stored conv state.
        conv_states: ``(..., dim, width - 1)``; updated in place when given.
        activation: ``None``, ``"silu"`` or ``"swish"``.
        pad_slot_id: entries of ``cache_indices`` equal to this value mark
            padded sequences the kernel must skip.

    Returns:
        ``x`` overwritten with the convolution output
        (``(batch, dim, seqlen)`` layout).
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError("activation must be None, silu, or swish")

    # The kernel requires a dense innermost dimension.
    if x.stride(-1) != 1:
        x = x.contiguous()
    if bias is not None:
        bias = bias.contiguous()

    apply_silu = activation in ("silu", "swish")
    causal_conv1d_fwd(
        x,
        weight,
        bias,
        conv_states,
        query_start_loc,
        cache_indices,
        has_initial_state,
        apply_silu,
        pad_slot_id,
    )
    return x
+
+
def causal_conv1d_update(
    x: torch.Tensor,
    conv_state: torch.Tensor,
    weight: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    activation: Optional[str] = None,
    cache_seqlens: Optional[torch.Tensor] = None,
    conv_state_indices: Optional[torch.Tensor] = None,
    pad_slot_id: int = -1,
):
    """Single-step (decode) causal conv1d update, in place on ``x``.

    Args:
        x: ``(batch, dim)`` or ``(batch, dim, seqlen)``.
        conv_state: ``(batch, dim, state_len)`` with ``state_len >= width - 1``.
        weight: ``(dim, width)``.
        bias: optional ``(dim,)``.
        activation: ``None``, ``"silu"`` or ``"swish"``.
        cache_seqlens: ``(batch,)`` int32; when given, ``conv_state`` is used
            as a circular buffer and ``x`` is written starting at
            ``cache_seqlens % state_len``.
        conv_state_indices: ``(batch,)`` int32; selects rows of a larger
            ``conv_state`` tensor (continuous batching).
        pad_slot_id: entries of ``conv_state_indices`` equal to this value
            mark padded slots the kernel must skip.

    Returns:
        ``x`` with the same rank it was passed in with.
    """
    if activation not in (None, "silu", "swish"):
        raise NotImplementedError(f"activation must be None, silu, or swish, actual: {activation}")
    use_activation = activation in ("silu", "swish")

    # The kernel wants a trailing seqlen axis; restore the 2-D shape on exit.
    squeeze_back = x.dim() == 2
    if squeeze_back:
        x = x.unsqueeze(-1)
    causal_conv1d_update_kernel(
        x,
        conv_state,
        weight,
        bias,
        use_activation,
        cache_seqlens,
        conv_state_indices,
        pad_slot_id,
    )
    return x.squeeze(-1) if squeeze_back else x
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py
new file mode 100644
index 0000000000..2bde70bb99
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py
@@ -0,0 +1,11 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+# Adapted from
+# https://github.com/vllm-project/vllm
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py
new file mode 100644
index 0000000000..cd3b0962a3
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py
@@ -0,0 +1,15 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+from .chunk import chunk_gated_delta_rule
+from .fused_recurrent import fused_recurrent_gated_delta_rule
+
+__all__ = [
+ "chunk_gated_delta_rule",
+ "fused_recurrent_gated_delta_rule",
+]
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py
new file mode 100644
index 0000000000..7b3067bbfb
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py
@@ -0,0 +1,224 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+import torch
+from einops import rearrange
+
+from .chunk_delta_h import chunk_gated_delta_rule_fwd_h
+from .chunk_o import chunk_fwd_o
+from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd
+from .cumsum import chunk_local_cumsum
+from .l2norm import l2norm_fwd
+from .solve_tril import solve_tril
+from .utils import SUPPRESS_LEVEL, input_guard
+from .wy_fast import recompute_w_u_fwd
+
+
def chunk_gated_delta_rule_fwd(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    scale: float,
    initial_state: torch.Tensor,
    output_final_state: bool,
    cu_seqlens: torch.LongTensor | None = None,
):
    """Forward pass of the chunked gated delta rule.

    Pipeline: local gate cumsum -> WY representation (A, w, u) -> chunkwise
    hidden-state recurrence -> output projection.  Intermediates (w, h,
    v_new) are only surfaced when SUPPRESS_LEVEL >= 3.
    """
    chunk_size = 64
    g = chunk_local_cumsum(g, chunk_size=chunk_size, cu_seqlens=cu_seqlens)

    # WY representation; ``u`` plays the role of the corrected values.
    A = chunk_scaled_dot_kkt_fwd(
        k=k,
        beta=beta,
        g=g,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
        output_dtype=torch.float32,
    )
    A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype)
    w, u = recompute_w_u_fwd(
        k=k,
        v=v,
        beta=beta,
        A=A,
        g_cumsum=g,
        cu_seqlens=cu_seqlens,
    )

    h, v_new, final_state = chunk_gated_delta_rule_fwd_h(
        k=k,
        w=w,
        u=u,
        g=g,
        initial_state=initial_state,
        output_final_state=output_final_state,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    o = chunk_fwd_o(
        q=q,
        k=k,
        v=v_new,
        h=h,
        g=g,
        scale=scale,
        cu_seqlens=cu_seqlens,
        chunk_size=chunk_size,
    )
    if SUPPRESS_LEVEL < 3:
        return g, o, A, final_state, None, None, None
    return g, o, A, final_state, w, h, v_new
+
+
class ChunkGatedDeltaRuleFunction(torch.autograd.Function):
    """Autograd wrapper around the chunked gated delta rule forward pass.

    NOTE(review): no ``backward`` is defined here — this path looks
    inference-only; ``ctx`` attributes are stashed but currently unused.
    """

    @staticmethod
    @input_guard
    @torch.amp.custom_fwd(device_type="cuda")
    def forward(
        ctx,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        scale: float,
        initial_state: torch.Tensor,
        output_final_state: bool,
        cu_seqlens: torch.LongTensor | None = None,
        use_qk_l2norm_in_kernel: bool = False,
    ):
        # Optionally L2-normalize q/k up front.
        if use_qk_l2norm_in_kernel:
            q, k = l2norm_fwd(q), l2norm_fwd(k)

        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
            q=q,
            k=k,
            v=v,
            g=g,
            beta=beta,
            scale=scale,
            initial_state=initial_state,
            output_final_state=output_final_state,
            cu_seqlens=cu_seqlens,
        )
        ctx.scale = scale
        ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel
        return o.to(q.dtype), final_state
+
+
+@torch.compiler.disable
+def chunk_gated_delta_rule(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float = None,
+ initial_state: torch.Tensor = None,
+ output_final_state: bool = False,
+ cu_seqlens: torch.LongTensor | None = None,
+ head_first: bool = False,
+ use_qk_l2norm_in_kernel: bool = False,
+):
+ r"""
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`.
+ v (torch.Tensor):
+ values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ g (torch.Tensor):
+ (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+ beta (torch.Tensor):
+ betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`.
+ scale (Optional[int]):
+ Scale factor for the RetNet attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[torch.Tensor]):
+ Initial state of shape `[N, H, K, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ output_final_state (Optional[bool]):
+ Whether to output the final state of shape `[N, H, K, V]`. Default: `False`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+ head_first (Optional[bool]):
+ Whether the inputs are in the head-first format, which is not supported for variable-length inputs.
+ Default: `False`.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`.
+ final_state (torch.Tensor):
+ Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`.
+
+ Examples::
+ >>> import torch
+ >>> import torch.nn.functional as F
+ >>> from einops import rearrange
+ >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule
+ # inputs with equal lengths
+ >>> B, T, H, K, V = 4, 2048, 4, 512, 512
+ >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda')
+ >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1)
+ >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda')
+ >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid()
+ >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda'))
+ >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda')
+ >>> o, ht = chunk_gated_delta_rule(
+ q, k, v, g, beta,
+ initial_state=h0,
+ output_final_state=True
+ )
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+ >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g))
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> o_var, ht_var = chunk_gated_delta_rule(
+ q, k, v, g, beta,
+ initial_state=h0,
+ output_final_state=True,
+ cu_seqlens=cu_seqlens
+ )
+ """
+ assert q.dtype == k.dtype == v.dtype
+ assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16."
+ assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise."
+
+ if head_first:
+ raise DeprecationWarning(
+ "head_first is deprecated and will be removed in a future version. "
+ "Please use head_first=False for now instead.",
+ stacklevel=2,
+ )
+ q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g))
+ if cu_seqlens is not None:
+ if q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+ f"Please flatten variable-length inputs before processing."
+ )
+ if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1:
+ raise ValueError(
+ f"The number of initial states is expected to be equal to the number of input sequences, "
+ f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}."
+ )
+ if scale is None:
+ scale = k.shape[-1] ** -0.5
+ o, final_state = ChunkGatedDeltaRuleFunction.apply(
+ q,
+ k,
+ v,
+ g,
+ beta,
+ scale,
+ initial_state,
+ output_final_state,
+ cu_seqlens,
+ use_qk_l2norm_in_kernel,
+ )
+ if head_first:
+ o = rearrange(o, "b t h ... -> b h t ...")
+ return o, final_state
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py
new file mode 100644
index 0000000000..97933b2ac2
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py
@@ -0,0 +1,324 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+
+import torch
+
+import triton
+import triton.language as tl
+
+from .index import prepare_chunk_indices, prepare_chunk_offsets
+from .op import exp, safe_exp
+from lightllm.common.triton_utils.autotuner import autotune
+
# Candidate warp counts; appears unused in this module — the autotune
# configs below are produced by _get_chunk_delta_h_configs instead.
NUM_WARPS = [2, 4, 8, 16]
+
+
@triton.heuristics(
    {
        "USE_G": lambda args: args["g"] is not None,
        "USE_GK": lambda args: args["gk"] is not None,
        "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
        "STORE_FINAL_STATE": lambda args: args["ht"] is not None,
        "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None,
        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
    }
)
@triton.jit(do_not_specialize=["T"])
def chunk_gated_delta_rule_fwd_kernel_h_blockdim64(
    k,
    v,
    w,
    v_new,
    g,
    gk,
    h,
    h0,
    ht,
    cu_seqlens,
    chunk_offsets,
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,
    BV: tl.constexpr,
    USE_G: tl.constexpr,
    USE_GK: tl.constexpr,
    USE_INITIAL_STATE: tl.constexpr,
    STORE_FINAL_STATE: tl.constexpr,
    SAVE_NEW_VALUE: tl.constexpr,
    IS_VARLEN: tl.constexpr,
):
    # Chunkwise recurrence over the [K, V] hidden state of the gated delta
    # rule.  The K dimension is held in fixed 64-row slabs (b_h1..b_h4), so
    # K <= 256 (asserted by the launcher).  One program instance handles one
    # (V-tile i_v, sequence i_n, head i_h).
    i_v, i_nh = tl.program_id(0), tl.program_id(1)
    i_n, i_h = i_nh // H, i_nh % H
    if IS_VARLEN:
        # Variable-length batch: resolve this sequence's [bos, eos) token
        # span and its first slot in the flattened chunk dimension.
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
        NT = tl.cdiv(T, BT)
        boh = tl.load(chunk_offsets + i_n).to(tl.int32)
    else:
        bos, eos = i_n * T, i_n * T + T
        NT = tl.cdiv(T, BT)
        boh = i_n * NT

    # [BK, BV]
    # Hidden-state accumulators, one 64-row slab per 64 channels of K.
    b_h1 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 64:
        b_h2 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 128:
        b_h3 = tl.zeros([64, BV], dtype=tl.float32)
    if K > 192:
        b_h4 = tl.zeros([64, BV], dtype=tl.float32)

    # calculate offset
    h += ((boh * H + i_h) * K * V).to(tl.int64)
    v += ((bos * H + i_h) * V).to(tl.int64)
    # k has Hg heads (may be fewer than H, shared across H // Hg value heads).
    k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64)
    w += ((bos * H + i_h) * K).to(tl.int64)
    if SAVE_NEW_VALUE:
        v_new += ((bos * H + i_h) * V).to(tl.int64)
    stride_v = H * V
    stride_h = H * K * V
    stride_k = Hg * K
    stride_w = H * K
    if USE_INITIAL_STATE:
        h0 = h0 + i_nh * K * V
    if STORE_FINAL_STATE:
        ht = ht + i_nh * K * V

    # load initial state
    if USE_INITIAL_STATE:
        p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32)
        if K > 64:
            p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32)
        if K > 128:
            p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32)
        if K > 192:
            p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32)

    # main recurrence
    for i_t in range(NT):
        # Store the state *entering* chunk i_t before updating it.
        p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1))

        # Predicted correction w @ h, accumulated slab by slab over K.
        p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0))
        b_w = tl.load(p_w, boundary_check=(0, 1))
        b_v = tl.dot(b_w, b_h1.to(b_w.dtype))
        if K > 64:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h2.to(b_w.dtype))
        if K > 128:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h3.to(b_w.dtype))
        if K > 192:
            p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0))
            b_w = tl.load(p_w, boundary_check=(0, 1))
            b_v += tl.dot(b_w, b_h4.to(b_w.dtype))
        # Delta-rule corrected values: v - w @ h.
        p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
        b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v

        if SAVE_NEW_VALUE:
            p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
            tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1))

        # Index of this chunk's last token (clamped at T for the tail chunk).
        last_idx = min((i_t + 1) * BT, T) - 1
        if USE_G:
            # Scalar gate: decay the state by exp(g_last) and rescale values
            # by exp(g_last - g) so all tokens align at the chunk boundary.
            b_g_last = tl.load(g + bos * H + last_idx * H + i_h)
            p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
            b_g = tl.load(p_g, boundary_check=(0,))
            b_v = b_v * safe_exp(b_g_last - b_g)[:, None]
            b_g_last = exp(b_g_last)
            b_h1 = b_h1 * b_g_last
            if K > 64:
                b_h2 = b_h2 * b_g_last
            if K > 128:
                b_h3 = b_h3 * b_g_last
            if K > 192:
                b_h4 = b_h4 * b_g_last

        if USE_GK:
            # Per-key-channel gate: decay each K row by exp(gk_last).
            o_k1 = tl.arange(0, 64)
            b_gk_last1 = tl.load(
                gk + (bos + last_idx) * H * K + i_h * K + o_k1,
                mask=(o_k1 < K),
                other=0.0,
            )
            b_h1 *= exp(b_gk_last1)[:, None]
            if K > 64:
                o_k2 = 64 + o_k1
                b_gk_last2 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k2,
                    mask=(o_k2 < K),
                    other=0.0,
                )
                b_h2 *= exp(b_gk_last2)[:, None]
            if K > 128:
                o_k3 = 128 + o_k1
                b_gk_last3 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k3,
                    mask=(o_k3 < K),
                    other=0.0,
                )
                b_h3 *= exp(b_gk_last3)[:, None]
            if K > 192:
                o_k4 = 192 + o_k1
                b_gk_last4 = tl.load(
                    gk + (bos + last_idx) * H * K + i_h * K + o_k4,
                    mask=(o_k4 < K),
                    other=0.0,
                )
                b_h4 *= exp(b_gk_last4)[:, None]
        b_v = b_v.to(k.dtype.element_ty)

        # State update: h += k^T @ v_new, slab by slab over K.
        p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1))
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_h1 += tl.dot(b_k, b_v)
        if K > 64:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h2 += tl.dot(b_k, b_v)
        if K > 128:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h3 += tl.dot(b_k, b_v)
        if K > 192:
            p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1))
            b_k = tl.load(p_k, boundary_check=(0, 1))
            b_h4 += tl.dot(b_k, b_v)
    # epilogue
    if STORE_FINAL_STATE:
        p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0))
        tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 64:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 128:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
        if K > 192:
            p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0))
            tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _get_chunk_delta_h_configs():
+ return [
+ {"BV": BV, "num_warps": num_warps, "num_stages": num_stages}
+ for num_warps in [2, 4]
+ for num_stages in [2, 3, 4]
+ for BV in [32, 64]
+ ]
+
+
+def _get_chunk_delta_h_static_key(k, u, chunk_size):
+ B, T, Hg, K = k.shape
+ V = u.shape[-1]
+ H = u.shape[-2]
+ return {"H": H, "K": K, "V": V, "BT": chunk_size}
+
+
+def _get_chunk_delta_h_run_key(k, u):
+ # Return batch * heads as run key
+ return k.shape[0] * k.shape[2]
+
+
@autotune(
    kernel_name="chunk_gated_delta_rule_fwd_h",
    configs_gen_func=_get_chunk_delta_h_configs,
    static_key_func=_get_chunk_delta_h_static_key,
    run_key_func=_get_chunk_delta_h_run_key,
)
def chunk_gated_delta_rule_fwd_h(
    k: torch.Tensor,
    w: torch.Tensor,
    u: torch.Tensor,
    g: torch.Tensor | None = None,
    gk: torch.Tensor | None = None,
    initial_state: torch.Tensor | None = None,
    output_final_state: bool = False,
    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
    save_new_value: bool = True,
    cu_seqlens: torch.LongTensor | None = None,
    run_config=None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """Launch the chunkwise hidden-state kernel.

    Returns ``(h, v_new, final_state)``: the per-chunk entering states, the
    delta-rule-corrected values (or ``None`` if ``save_new_value`` is off),
    and the final state per sequence (``None`` unless ``output_final_state``).
    """
    # This kernel is slightly different from fla to support Q/K with different head numbers.
    # In fla, Q/K always have the same head number, so Hg is always equal to H.
    B, T, Hg, K, V = *k.shape, u.shape[-1]
    H = u.shape[-2]
    BT = chunk_size

    chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
    # N: the actual number of sequences in the batch with either equal or variable lengths
    if cu_seqlens is None:
        N, NT, chunk_offsets = B, triton.cdiv(T, BT), None
    else:
        N, NT, chunk_offsets = (
            len(cu_seqlens) - 1,
            len(chunk_indices),
            prepare_chunk_offsets(cu_seqlens, BT),
        )
    # The kernel holds the K dimension in four fixed 64-row slabs.
    assert K <= 256, "current kernel does not support head dimension larger than 256."

    # Per-chunk entering states: [B, NT, H, K, V].
    h = k.new_empty(B, NT, H, K, V)
    final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None

    v_new = torch.empty_like(u) if save_new_value else None

    # Extract config parameters
    # Fallback config used when the autotuner does not supply one.
    if run_config is None:
        run_config = {"BV": 64, "num_warps": 2, "num_stages": 2}

    BV = run_config.get("BV", 64)
    num_warps = run_config.get("num_warps", 2)
    num_stages = run_config.get("num_stages", 2)

    # One program per (V tile, sequence x head).
    grid = (triton.cdiv(V, BV), N * H)

    chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid](
        k=k,
        v=u,
        w=w,
        v_new=v_new,
        g=g,
        gk=gk,
        h=h,
        h0=initial_state,
        ht=final_state,
        cu_seqlens=cu_seqlens,
        chunk_offsets=chunk_offsets,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BV=BV,
        num_warps=num_warps,
        num_stages=num_stages,
    )
    return h, v_new, final_state
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py
new file mode 100644
index 0000000000..fc49763ecd
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py
@@ -0,0 +1,205 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+# ruff: noqa: E501
+
+
+import torch
+
+import triton
+import triton.language as tl
+
+from .index import prepare_chunk_indices
+from .op import exp, safe_exp
+from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper
+from lightllm.common.triton_utils.autotuner import autotune
+
+BKV_LIST = [64, 128] if check_shared_mem() else [32, 64]
+NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8]
+
+
+@triton.heuristics(
+ {
+ "USE_G": lambda args: args["g"] is not None,
+ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+ }
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_fwd_kernel_o(
+ q,
+ k,
+ v,
+ h,
+ g,
+ o,
+ cu_seqlens,
+ chunk_indices,
+ scale,
+ T,
+ H: tl.constexpr,
+ Hg: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ USE_G: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+):
+ i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_b, i_h = i_bh // H, i_bh % H
+
+ if IS_VARLEN:
+ i_tg = i_t
+ i_n, i_t = (
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+ )
+ bos, eos = (
+ tl.load(cu_seqlens + i_n).to(tl.int32),
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+ )
+ T = eos - bos
+ NT = tl.cdiv(T, BT)
+ else:
+ NT = tl.cdiv(T, BT)
+ i_tg = i_b * NT + i_t
+ bos, eos = i_b * T, i_b * T + T
+
+ # offset calculation
+ q += (bos * Hg + i_h // (H // Hg)) * K
+ k += (bos * Hg + i_h // (H // Hg)) * K
+ v += (bos * H + i_h) * V
+ o += (bos * H + i_h) * V
+ h += (i_tg * H + i_h).to(tl.int64) * K * V
+
+ b_o = tl.zeros([BT, BV], dtype=tl.float32)
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
+
+ for i_k in range(tl.cdiv(K, BK)):
+ p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+ p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1))
+ p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))
+ # [BT, BK]
+ b_q = tl.load(p_q, boundary_check=(0, 1))
+ # [BK, BT]
+ b_k = tl.load(p_k, boundary_check=(0, 1))
+ # [BK, BV]
+ b_h = tl.load(p_h, boundary_check=(0, 1))
+
+ # [BT, BK] @ [BK, BV] -> [BT, BV]
+ b_o += tl.dot(b_q, b_h)
+ # [BT, BK] @ [BK, BT] -> [BT, BT]
+ b_A += tl.dot(b_q, b_k)
+
+ if USE_G:
+ g += bos * H + i_h
+ p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,))
+ b_g = tl.load(p_g, boundary_check=(0,))
+ b_o = b_o * exp(b_g)[:, None]
+ b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :])
+
+ o_t = i_t * BT + tl.arange(0, BT)
+ m_t = o_t < T
+ m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t)
+ b_A = tl.where(m_A, b_A, 0)
+
+ p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+ b_v = tl.load(p_v, boundary_check=(0, 1))
+
+    # Workaround for an mma -> mma layout conversion issue;
+    # it is already fixed in Triton v3.2 and higher.
+ b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _get_chunk_o_configs():
+ return [
+ {"BK": BK, "BV": BV, "num_warps": num_warps, "num_stages": num_stages}
+ for BK in BKV_LIST
+ for BV in BKV_LIST
+ for num_warps in NUM_WARPS
+ for num_stages in [2, 3, 4]
+ ]
+
+
+def _get_chunk_o_static_key(q, v, chunk_size):
+ B, T, Hg, K = q.shape
+ V = v.shape[-1]
+ H = v.shape[-2]
+ BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
+ return {"H": H, "K": K, "V": V, "BT": BT}
+
+
+def _get_chunk_o_run_key(q, v):
+ # Return batch * heads as run key
+ return q.shape[0] * q.shape[2]
+
+
+@autotune(
+ kernel_name="chunk_fwd_o",
+ configs_gen_func=_get_chunk_o_configs,
+ static_key_func=_get_chunk_o_static_key,
+ run_key_func=_get_chunk_o_run_key,
+)
+def chunk_fwd_o(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ h: torch.Tensor,
+ g: torch.Tensor | None = None, # cumsum of log decay
+ scale: float | None = None,
+ cu_seqlens: torch.LongTensor | None = None,
+ chunk_size: int = 64,
+ run_config=None,
+) -> torch.Tensor:
+ B, T, Hg, K, V = *q.shape, v.shape[-1]
+ H = v.shape[-2]
+ BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T)))
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+ if scale is None:
+ scale = k.shape[-1] ** -0.5
+
+ o = torch.empty_like(v)
+
+ # Extract config parameters
+ if run_config is None:
+ run_config = {"BK": 64, "BV": 64, "num_warps": 2, "num_stages": 2}
+
+ BK = run_config.get("BK", 64)
+ BV = run_config.get("BV", 64)
+ num_warps = run_config.get("num_warps", 2)
+ num_stages = run_config.get("num_stages", 2)
+
+ grid = (triton.cdiv(V, BV), NT, B * H)
+
+ chunk_fwd_kernel_o[grid](
+ q,
+ k,
+ v,
+ h,
+ g,
+ o,
+ cu_seqlens,
+ chunk_indices,
+ scale,
+ T=T,
+ H=H,
+ Hg=Hg,
+ K=K,
+ V=V,
+ BT=BT,
+ BK=BK,
+ BV=BV,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return o
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py
new file mode 100644
index 0000000000..60a594c078
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py
@@ -0,0 +1,180 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+
+import torch
+
+import triton
+import triton.language as tl
+
+from .index import prepare_chunk_indices
+from .op import safe_exp
+from lightllm.common.triton_utils.autotuner import autotune
+
+
+@triton.heuristics(
+ {
+ "USE_G": lambda args: args["g"] is not None,
+ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+ }
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_scaled_dot_kkt_fwd_kernel(
+ k,
+ beta,
+ g,
+ A,
+ cu_seqlens,
+ chunk_indices,
+ T,
+ H: tl.constexpr,
+ Hg: tl.constexpr,
+ K: tl.constexpr,
+ BT: tl.constexpr,
+ BK: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ USE_G: tl.constexpr,
+):
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
+ i_b, i_h = i_bh // H, i_bh % H
+ if IS_VARLEN:
+ i_n, i_t = (
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+ )
+ bos, eos = (
+ tl.load(cu_seqlens + i_n).to(tl.int32),
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+ )
+ T = eos - bos
+ else:
+ bos, eos = i_b * T, i_b * T + T
+ o_t = i_t * BT + tl.arange(0, BT)
+ m_t = o_t < T
+
+ p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+ b_beta = tl.load(p_beta, boundary_check=(0,))
+
+ b_A = tl.zeros([BT, BT], dtype=tl.float32)
+ for i_k in range(tl.cdiv(K, BK)):
+ p_k = tl.make_block_ptr(
+ k + (bos * Hg + i_h // (H // Hg)) * K,
+ (T, K),
+ (Hg * K, 1),
+ (i_t * BT, i_k * BK),
+ (BT, BK),
+ (1, 0),
+ )
+ b_k = tl.load(p_k, boundary_check=(0, 1))
+ b_A += tl.dot(b_k, tl.trans(b_k))
+
+ if USE_G:
+ p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+ b_g = tl.load(p_g, boundary_check=(0,))
+ b_g_diff = b_g[:, None] - b_g[None, :]
+ b_A = b_A * safe_exp(b_g_diff)
+
+ b_A *= b_beta[:, None]
+ m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t)
+ b_A = tl.where(m_A, b_A, 0)
+ p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+ tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _get_chunk_scaled_dot_kkt_configs():
+ return [
+ {"BK": BK, "num_warps": num_warps, "num_stages": num_stages}
+ for BK in [32, 64, 128]
+ for num_warps in [2, 4, 8]
+ for num_stages in [2, 3, 4]
+ ]
+
+
+def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size=64, cu_seqlens=None):
+ B, T, Hg, K = k.shape
+ H = beta.shape[-1]
+ IS_VARLEN = cu_seqlens is not None
+ return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN}
+
+
+def _get_chunk_scaled_dot_kkt_run_key(k, beta):
+ # Return batch * heads as run key
+ return k.shape[0] * k.shape[2]
+
+
+@autotune(
+ kernel_name="chunk_scaled_dot_kkt_fwd",
+ configs_gen_func=_get_chunk_scaled_dot_kkt_configs,
+ static_key_func=_get_chunk_scaled_dot_kkt_static_key,
+ run_key_func=_get_chunk_scaled_dot_kkt_run_key,
+)
+def chunk_scaled_dot_kkt_fwd(
+ k: torch.Tensor,
+ g: torch.Tensor | None = None,
+ beta: torch.Tensor | None = None,
+ cu_seqlens: torch.LongTensor | None = None,
+ chunk_size: int = 64,
+ output_dtype: torch.dtype = torch.float32,
+ run_config=None,
+) -> torch.Tensor:
+ r"""
+ Compute beta * K * K^T.
+
+ Args:
+ k (torch.Tensor):
+ The key tensor of shape `[B, T, H, K]`.
+        g (torch.Tensor):
+            The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`.
+        beta (torch.Tensor):
+            The beta tensor of shape `[B, T, H]`.
+ cu_seqlens (torch.LongTensor):
+ The cumulative sequence lengths of the input tensor.
+ Default: None
+ chunk_size (int):
+ The chunk size. Default: 64.
+ output_dtype (torch.dtype):
+ The dtype of the output tensor. Default: `torch.float32`
+
+ Returns:
+ beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size.
+ """
+ # This kernel is slightly different from fla to support Q/K with different head numbers.
+ # In fla, Q/K always have the same head number, so Hg is always equal to H.
+ B, T, Hg, K = k.shape
+ H = beta.shape[-1]
+ BT = chunk_size
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+ # Extract config parameters
+ if run_config is None:
+ run_config = {"BK": 64, "num_warps": 2, "num_stages": 2}
+
+ BK = run_config.get("BK", 64)
+ num_warps = run_config.get("num_warps", 2)
+ num_stages = run_config.get("num_stages", 2)
+
+ A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype)
+ chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)](
+ k=k,
+ g=g,
+ beta=beta,
+ A=A,
+ cu_seqlens=cu_seqlens,
+ chunk_indices=chunk_indices,
+ T=T,
+ H=H,
+ Hg=Hg,
+ K=K,
+ BT=BT,
+ BK=BK,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ return A
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py
new file mode 100644
index 0000000000..6331e1602d
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py
@@ -0,0 +1,306 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+import torch
+
+import triton
+import triton.language as tl
+
+from .index import prepare_chunk_indices
+from .utils import check_shared_mem, input_guard
+from lightllm.common.triton_utils.autotuner import autotune
+
+BS_LIST = [32, 64] if check_shared_mem() else [16, 32]
+
+
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+@triton.jit(do_not_specialize=["T"])
+def chunk_local_cumsum_scalar_kernel(
+ s,
+ o,
+ cu_seqlens,
+ chunk_indices,
+ T,
+ B: tl.constexpr,
+ H: tl.constexpr,
+ BT: tl.constexpr,
+ REVERSE: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ HEAD_FIRST: tl.constexpr,
+):
+ i_t, i_bh = tl.program_id(0), tl.program_id(1)
+ i_b, i_h = i_bh // H, i_bh % H
+ if IS_VARLEN:
+ i_n, i_t = (
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+ )
+ bos, eos = (
+ tl.load(cu_seqlens + i_n).to(tl.int32),
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+ )
+ T = eos - bos
+ else:
+ bos, eos = i_b * T, i_b * T + T
+
+ if HEAD_FIRST:
+ p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+ p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,))
+ else:
+ p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+ p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+ # [BT]
+ b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)
+ b_o = tl.cumsum(b_s, axis=0)
+ if REVERSE:
+ b_z = tl.sum(b_s, axis=0)
+ b_o = -b_o + b_z[None] + b_s
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))
+
+
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+@triton.jit(do_not_specialize=["T"])
+def chunk_local_cumsum_vector_kernel(
+ s,
+ o,
+ cu_seqlens,
+ chunk_indices,
+ T,
+ B: tl.constexpr,
+ H: tl.constexpr,
+ S: tl.constexpr,
+ BT: tl.constexpr,
+ BS: tl.constexpr,
+ REVERSE: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ HEAD_FIRST: tl.constexpr,
+):
+ i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_b, i_h = i_bh // H, i_bh % H
+ if IS_VARLEN:
+ i_n, i_t = (
+ tl.load(chunk_indices + i_t * 2).to(tl.int32),
+ tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+ )
+ bos, eos = (
+ tl.load(cu_seqlens + i_n).to(tl.int32),
+ tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+ )
+ T = eos - bos
+ else:
+ bos, eos = i_b * T, i_b * T + T
+
+ o_i = tl.arange(0, BT)
+ if REVERSE:
+ m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0)
+ else:
+ m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0)
+
+ if HEAD_FIRST:
+ p_s = tl.make_block_ptr(
+ s + (bos * H + i_h * T) * S,
+ (T, S),
+ (S, 1),
+ (i_t * BT, i_s * BS),
+ (BT, BS),
+ (1, 0),
+ )
+ p_o = tl.make_block_ptr(
+ o + (bos * H + i_h * T) * S,
+ (T, S),
+ (S, 1),
+ (i_t * BT, i_s * BS),
+ (BT, BS),
+ (1, 0),
+ )
+ else:
+ p_s = tl.make_block_ptr(
+ s + (bos * H + i_h) * S,
+ (T, S),
+ (H * S, 1),
+ (i_t * BT, i_s * BS),
+ (BT, BS),
+ (1, 0),
+ )
+ p_o = tl.make_block_ptr(
+ o + (bos * H + i_h) * S,
+ (T, S),
+ (H * S, 1),
+ (i_t * BT, i_s * BS),
+ (BT, BS),
+ (1, 0),
+ )
+ # [BT, BS]
+ b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)
+ b_o = tl.dot(m_s, b_s, allow_tf32=False)
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))
+
+
+def _get_cumsum_scalar_configs():
+ return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8]]
+
+
+def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first):
+ if head_first:
+ B, H, T = g.shape
+ else:
+ B, T, H = g.shape
+ IS_VARLEN = cu_seqlens is not None
+ return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse}
+
+
+def _get_cumsum_scalar_run_key(g):
+ # Return total number of elements as run key
+ return g.shape[0] * g.shape[1]
+
+
+@autotune(
+ kernel_name="chunk_local_cumsum_scalar",
+ configs_gen_func=_get_cumsum_scalar_configs,
+ static_key_func=_get_cumsum_scalar_static_key,
+ run_key_func=_get_cumsum_scalar_run_key,
+)
+def chunk_local_cumsum_scalar(
+ g: torch.Tensor,
+ chunk_size: int,
+ reverse: bool = False,
+ cu_seqlens: torch.Tensor | None = None,
+ head_first: bool = False,
+ output_dtype: torch.dtype | None = torch.float,
+ run_config=None,
+) -> torch.Tensor:
+ if head_first:
+ B, H, T = g.shape
+ else:
+ B, T, H = g.shape
+ assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
+ BT = chunk_size
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+ g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
+
+ # Extract config parameters
+ if run_config is None:
+ run_config = {"num_warps": 2}
+
+ num_warps = run_config.get("num_warps", 2)
+
+ grid = (NT, B * H)
+ chunk_local_cumsum_scalar_kernel[grid](
+ g_org,
+ g,
+ cu_seqlens,
+ chunk_indices,
+ T=T,
+ B=B,
+ H=H,
+ BT=BT,
+ HEAD_FIRST=head_first,
+ REVERSE=reverse,
+ num_warps=num_warps,
+ )
+ return g
+
+
+def _get_cumsum_vector_configs():
+ return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]]
+
+
+def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first):
+ if head_first:
+ B, H, T, S = g.shape
+ else:
+ B, T, H, S = g.shape
+ IS_VARLEN = cu_seqlens is not None
+ return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse}
+
+
+def _get_cumsum_vector_run_key(g):
+ # Return batch * heads as run key
+ return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0]
+
+
+@autotune(
+ kernel_name="chunk_local_cumsum_vector",
+ configs_gen_func=_get_cumsum_vector_configs,
+ static_key_func=_get_cumsum_vector_static_key,
+ run_key_func=_get_cumsum_vector_run_key,
+)
+def chunk_local_cumsum_vector(
+ g: torch.Tensor,
+ chunk_size: int,
+ reverse: bool = False,
+ cu_seqlens: torch.Tensor | None = None,
+ head_first: bool = False,
+ output_dtype: torch.dtype | None = torch.float,
+ run_config=None,
+) -> torch.Tensor:
+ if head_first:
+ B, H, T, S = g.shape
+ else:
+ B, T, H, S = g.shape
+ BT = chunk_size
+ chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None
+ NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+ assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2"
+
+ g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype)
+
+ # Extract config parameters
+ if run_config is None:
+ run_config = {"BS": 32, "num_warps": 2}
+
+ BS = run_config.get("BS", 32)
+ num_warps = run_config.get("num_warps", 2)
+
+ grid = (triton.cdiv(S, BS), NT, B * H)
+
+ # keep cumulative normalizer in fp32
+ # this kernel is equivalent to
+ # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)
+ chunk_local_cumsum_vector_kernel[grid](
+ g_org,
+ g,
+ cu_seqlens,
+ chunk_indices,
+ T=T,
+ B=B,
+ H=H,
+ S=S,
+ BT=BT,
+ BS=BS,
+ HEAD_FIRST=head_first,
+ REVERSE=reverse,
+ num_warps=num_warps,
+ )
+ return g
+
+
+@input_guard
+def chunk_local_cumsum(
+ g: torch.Tensor,
+ chunk_size: int,
+ reverse: bool = False,
+ cu_seqlens: torch.Tensor | None = None,
+ head_first: bool = False,
+ output_dtype: torch.dtype | None = torch.float,
+ **kwargs,
+) -> torch.Tensor:
+ if cu_seqlens is not None:
+ assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided"
+ if len(g.shape) == 3:
+ return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
+ elif len(g.shape) == 4:
+ return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype)
+ else:
+ raise ValueError(
+ f"Unsupported input shape {g.shape}. "
+ f"which should be (B, T, H, D) if `head_first=False` "
+ f"or (B, H, T, D) otherwise"
+ )
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py
new file mode 100644
index 0000000000..22a93a2c99
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py
@@ -0,0 +1,492 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+
+import torch
+
+import triton
+import triton.language as tl
+
+from .op import exp
+
+
+@triton.heuristics(
+ {
+ "USE_INITIAL_STATE": lambda args: args["h0"] is not None,
+ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+ "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None,
+ "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None,
+ "HAS_SEPARATE_WRITE_INDICES": lambda args: args["ssm_state_write_indices"] is not None,
+ }
+)
+@triton.jit(do_not_specialize=["N", "T"])
+def fused_recurrent_gated_delta_rule_fwd_kernel(
+ q,
+ k,
+ v,
+ g,
+ beta,
+ o,
+ h0,
+ ht,
+ cu_seqlens,
+ ssm_state_indices,
+ ssm_state_write_indices, # NEW: separate write indices for state propagation optimization
+ num_accepted_tokens,
+ # Fused gating parameters (only used when FUSE_GATING=True)
+ A_log, # [HV] per-head log decay
+ dt_bias, # [HV] per-head dt bias
+ a_raw, # [B*T, HV] raw alpha values (before softplus)
+ b_raw, # [B*T, HV] raw beta values (before sigmoid)
+ scale,
+ N: tl.int64, # num of sequences
+ T: tl.int64, # num of tokens
+ B: tl.constexpr,
+ H: tl.constexpr,
+ HV: tl.constexpr,
+ K: tl.constexpr,
+ V: tl.constexpr,
+ BK: tl.constexpr,
+ BV: tl.constexpr,
+ stride_init_state_token: tl.constexpr,
+ stride_final_state_token: tl.constexpr,
+ stride_indices_seq: tl.constexpr,
+ stride_indices_tok: tl.constexpr,
+ stride_write_indices_seq: tl.constexpr, # NEW: stride for write indices
+ stride_write_indices_tok: tl.constexpr, # NEW: stride for write indices
+ SOFTPLUS_BETA: tl.constexpr, # softplus beta parameter (default 1.0)
+ SOFTPLUS_THRESHOLD: tl.constexpr, # softplus threshold (default 20.0)
+ USE_INITIAL_STATE: tl.constexpr, # whether to use initial state
+ INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace
+ IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar,
+ USE_QK_L2NORM_IN_KERNEL: tl.constexpr,
+ IS_VARLEN: tl.constexpr,
+ IS_CONTINUOUS_BATCHING: tl.constexpr,
+ IS_SPEC_DECODING: tl.constexpr,
+ IS_KDA: tl.constexpr,
+ HAS_SEPARATE_WRITE_INDICES: tl.constexpr, # NEW: whether to use separate write indices
+ FUSE_GATING: tl.constexpr, # whether to compute g/beta inline from raw values
+):
+ i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2)
+ i_n, i_hv = i_nh // HV, i_nh % HV
+ i_h = i_hv // (HV // H)
+ if IS_VARLEN:
+ bos, eos = (
+ tl.load(cu_seqlens + i_n).to(tl.int64),
+ tl.load(cu_seqlens + i_n + 1).to(tl.int64),
+ )
+ all = T
+ T = eos - bos
+ else:
+ bos, eos = i_n * T, i_n * T + T
+ all = B * T
+
+ if T == 0:
+ # no tokens to process for this sequence
+ return
+
+ o_k = i_k * BK + tl.arange(0, BK)
+ o_v = i_v * BV + tl.arange(0, BV)
+
+ p_q = q + (bos * H + i_h) * K + o_k
+ p_k = k + (bos * H + i_h) * K + o_k
+ p_v = v + (bos * HV + i_hv) * V + o_v
+ if FUSE_GATING:
+ # Fused gating: load per-head constants once, compute g/beta inline per token
+ b_A_log = tl.load(A_log + i_hv).to(tl.float32)
+ b_dt_bias = tl.load(dt_bias + i_hv).to(tl.float32)
+ p_a_raw = a_raw + bos * HV + i_hv
+ p_b_raw = b_raw + bos * HV + i_hv
+ else:
+ if IS_BETA_HEADWISE:
+ p_beta = beta + (bos * HV + i_hv) * V + o_v
+ else:
+ p_beta = beta + bos * HV + i_hv
+
+ if not IS_KDA:
+ p_g = g + bos * HV + i_hv
+ else:
+ p_gk = g + (bos * HV + i_hv) * K + o_k
+
+ p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v
+
+ mask_k = o_k < K
+ mask_v = o_v < V
+ mask_h = mask_k[:, None] & mask_v[None, :]
+
+ b_h = tl.zeros([BK, BV], dtype=tl.float32)
+ if USE_INITIAL_STATE:
+ if IS_CONTINUOUS_BATCHING:
+ if IS_SPEC_DECODING:
+ i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1
+ else:
+ i_t = 0
+ p_h0 = (
+ h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token
+ )
+ else:
+ p_h0 = h0 + bos * HV * K * V
+ p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
+ b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32)
+
+ for i_t in range(0, T):
+ b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32)
+ b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32)
+ b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32)
+
+ if USE_QK_L2NORM_IN_KERNEL:
+ b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6)
+ b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6)
+ b_q = b_q * scale
+ # [BK, BV]
+ if FUSE_GATING:
+ # Compute g = -exp(A_log) * softplus(a_raw + dt_bias) inline
+ b_a = tl.load(p_a_raw).to(tl.float32)
+ x = b_a + b_dt_bias
+ softplus_x = tl.where(
+ SOFTPLUS_BETA * x <= SOFTPLUS_THRESHOLD,
+ (1.0 / SOFTPLUS_BETA) * tl.log(1.0 + tl.exp(SOFTPLUS_BETA * x)),
+ x,
+ )
+ b_g = -tl.exp(b_A_log) * softplus_x
+ b_h *= exp(b_g)
+ # Compute beta = sigmoid(b_raw) inline
+ b_b = tl.load(p_b_raw).to(tl.float32)
+ b_beta = tl.sigmoid(b_b)
+ else:
+ if not IS_KDA:
+ b_g = tl.load(p_g).to(tl.float32)
+ b_h *= exp(b_g)
+ else:
+ b_gk = tl.load(p_gk).to(tl.float32)
+ b_h *= exp(b_gk[:, None])
+ if IS_BETA_HEADWISE:
+ b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32)
+ else:
+ b_beta = tl.load(p_beta).to(tl.float32)
+ # [BV]
+ b_v -= tl.sum(b_h * b_k[:, None], 0)
+ b_v *= b_beta
+ # [BK, BV]
+ b_h += b_k[:, None] * b_v[None, :]
+ # [BV]
+ b_o = tl.sum(b_h * b_q[:, None], 0)
+ tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v)
+
+ # keep the states for multi-query tokens
+ if INPLACE_FINAL_STATE:
+ # Use separate write indices if provided (for state propagation optimization)
+ # Otherwise fall back to read indices
+ if HAS_SEPARATE_WRITE_INDICES:
+ write_idx = tl.load(ssm_state_write_indices + i_n * stride_write_indices_seq + i_t).to(tl.int64)
+ else:
+ write_idx = tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64)
+ p_ht = ht + write_idx * stride_final_state_token
+ else:
+ p_ht = ht + (bos + i_t) * stride_final_state_token
+ p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :]
+ tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h)
+
+ p_q += H * K
+ p_k += H * K
+ p_o += HV * V
+ p_v += HV * V
+ if FUSE_GATING:
+ p_a_raw += HV
+ p_b_raw += HV
+ else:
+ if not IS_KDA:
+ p_g += HV
+ else:
+ p_gk += HV * K
+ p_beta += HV * (V if IS_BETA_HEADWISE else 1)
+
+
+def fused_recurrent_gated_delta_rule_fwd(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ inplace_final_state: bool = True,
+ cu_seqlens: torch.LongTensor | None = None,
+ ssm_state_indices: torch.Tensor | None = None,
+ ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices
+ num_accepted_tokens: torch.Tensor | None = None,
+ use_qk_l2norm_in_kernel: bool = False,
+ # Fused gating parameters
+ A_log: torch.Tensor | None = None,
+ dt_bias: torch.Tensor | None = None,
+ a_raw: torch.Tensor | None = None,
+ b_raw: torch.Tensor | None = None,
+ out: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ B, T, H, K, V = *k.shape, v.shape[-1]
+ HV = v.shape[2]
+ N = B if cu_seqlens is None else len(cu_seqlens) - 1
+ BK = triton.next_power_of_2(K)
+ if T == 1:
+ # Decode path: use larger BV to reduce kernel instances (4 blocks instead of 16)
+ # and more warps for better SM utilization at T=1 where there's no pipelining benefit
+ BV = min(triton.next_power_of_2(V), 32)
+ num_warps = 4
+ num_stages = 1
+ else:
+ # Prefill path: small BV for better pipelining across sequence length
+ BV = min(triton.next_power_of_2(V), 8)
+ num_warps = 1
+ num_stages = 3
+ NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)
+ assert NK == 1, "NK > 1 is not supported yet"
+
+ fuse_gating = A_log is not None
+
+ if out is not None:
+ o = out.unsqueeze(0) if out.ndim == v.ndim else out
+ else:
+ o = q.new_empty(NK, *v.shape)
+ if inplace_final_state:
+ final_state = initial_state
+ else:
+ final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype)
+
+ stride_init_state_token = initial_state.stride(0)
+ stride_final_state_token = final_state.stride(0)
+
+ # Strides for read indices
+ if ssm_state_indices is None:
+ stride_indices_seq, stride_indices_tok = 1, 1
+ elif ssm_state_indices.ndim == 1:
+ stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1
+ else:
+ stride_indices_seq, stride_indices_tok = ssm_state_indices.stride()
+
+ # Strides for write indices (if provided)
+ if ssm_state_write_indices is None:
+ stride_write_indices_seq, stride_write_indices_tok = 1, 1
+ elif ssm_state_write_indices.ndim == 1:
+ stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride(0), 1
+ else:
+ stride_write_indices_seq, stride_write_indices_tok = ssm_state_write_indices.stride()
+
+ grid = (NK, NV, N * HV)
+ fused_recurrent_gated_delta_rule_fwd_kernel[grid](
+ q=q,
+ k=k,
+ v=v,
+ g=g,
+ beta=beta,
+ o=o,
+ h0=initial_state,
+ ht=final_state,
+ cu_seqlens=cu_seqlens,
+ ssm_state_indices=ssm_state_indices,
+ ssm_state_write_indices=ssm_state_write_indices,
+ num_accepted_tokens=num_accepted_tokens,
+ A_log=A_log,
+ dt_bias=dt_bias,
+ a_raw=a_raw,
+ b_raw=b_raw,
+ scale=scale,
+ N=N,
+ T=T,
+ B=B,
+ H=H,
+ HV=HV,
+ K=K,
+ V=V,
+ BK=BK,
+ BV=BV,
+ stride_init_state_token=stride_init_state_token,
+ stride_final_state_token=stride_final_state_token,
+ stride_indices_seq=stride_indices_seq,
+ stride_indices_tok=stride_indices_tok,
+ stride_write_indices_seq=stride_write_indices_seq,
+ stride_write_indices_tok=stride_write_indices_tok,
+ SOFTPLUS_BETA=1.0,
+ SOFTPLUS_THRESHOLD=20.0,
+ IS_BETA_HEADWISE=False if fuse_gating else (beta.ndim == v.ndim),
+ USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel,
+ INPLACE_FINAL_STATE=inplace_final_state,
+ IS_KDA=False,
+ FUSE_GATING=fuse_gating,
+ num_warps=num_warps,
+ num_stages=num_stages,
+ )
+ o = o.squeeze(0)
+ return o, final_state
+
+
+class FusedRecurrentFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(
+ ctx,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor,
+ beta: torch.Tensor,
+ scale: float,
+ initial_state: torch.Tensor,
+ inplace_final_state: bool = True,
+ cu_seqlens: torch.LongTensor | None = None,
+ ssm_state_indices: torch.Tensor | None = None,
+ ssm_state_write_indices: torch.Tensor | None = None,
+ num_accepted_tokens: torch.Tensor | None = None,
+ use_qk_l2norm_in_kernel: bool = False,
+ A_log: torch.Tensor | None = None,
+ dt_bias: torch.Tensor | None = None,
+ a_raw: torch.Tensor | None = None,
+ b_raw: torch.Tensor | None = None,
+ out: torch.Tensor | None = None,
+ ):
+ o, final_state = fused_recurrent_gated_delta_rule_fwd(
+ q=q.contiguous(),
+ k=k.contiguous(),
+ v=v.contiguous(),
+ g=g.contiguous() if g is not None else None,
+ beta=beta.contiguous() if beta is not None else None,
+ scale=scale,
+ initial_state=initial_state,
+ inplace_final_state=inplace_final_state,
+ cu_seqlens=cu_seqlens,
+ ssm_state_indices=ssm_state_indices,
+ ssm_state_write_indices=ssm_state_write_indices,
+ num_accepted_tokens=num_accepted_tokens,
+ use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel,
+ A_log=A_log,
+ dt_bias=dt_bias,
+ a_raw=a_raw.contiguous() if a_raw is not None else None,
+ b_raw=b_raw.contiguous() if b_raw is not None else None,
+ out=out,
+ )
+
+ return o, final_state
+
+
+def fused_recurrent_gated_delta_rule(
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ g: torch.Tensor = None,
+ beta: torch.Tensor = None,
+ scale: float = None,
+ initial_state: torch.Tensor = None,
+ inplace_final_state: bool = True,
+ cu_seqlens: torch.LongTensor | None = None,
+ ssm_state_indices: torch.Tensor | None = None,
+ ssm_state_write_indices: torch.Tensor | None = None, # NEW: separate write indices for state propagation
+ num_accepted_tokens: torch.Tensor | None = None,
+ use_qk_l2norm_in_kernel: bool = False,
+ # Fused gating: pass raw values to compute g/beta inline in the kernel
+ A_log: torch.Tensor | None = None,
+ dt_bias: torch.Tensor | None = None,
+ a_raw: torch.Tensor | None = None,
+ b_raw: torch.Tensor | None = None,
+ out: torch.Tensor | None = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ r"""
+ Args:
+ q (torch.Tensor):
+ queries of shape `[B, T, H, K]`.
+ k (torch.Tensor):
+ keys of shape `[B, T, H, K]`.
+ v (torch.Tensor):
+ values of shape `[B, T, HV, V]`.
+ GVA is applied if `HV > H`.
+ g (torch.Tensor):
+ g (decays) of shape `[B, T, HV]`.
+ beta (torch.Tensor):
+ betas of shape `[B, T, HV]`.
+ scale (Optional[int]):
+ Scale factor for the RetNet attention scores.
+ If not provided, it will default to `1 / sqrt(K)`. Default: `None`.
+ initial_state (Optional[torch.Tensor]):
+ Initial state of shape `[N, HV, K, V]` for `N` input sequences.
+ For equal-length input sequences, `N` equals the batch size `B`.
+ Default: `None`.
+ inplace_final_state: bool:
+ Whether to store the final state in-place to save memory.
+ Default: `True`.
+ cu_seqlens (torch.LongTensor):
+ Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
+ consistent with the FlashAttention API.
+ ssm_state_indices (Optional[torch.Tensor]):
+ Indices to map the input sequences to the initial/final states.
+ num_accepted_tokens (Optional[torch.Tensor]):
+ Number of accepted tokens for each sequence during decoding.
+
+ Returns:
+ o (torch.Tensor):
+ Outputs of shape `[B, T, HV, V]`.
+ final_state (torch.Tensor):
+ Final state of shape `[N, HV, K, V]`.
+
+ Examples::
+ >>> import torch
+ >>> import torch.nn.functional as F
+ >>> from einops import rearrange
+ >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule
+ # inputs with equal lengths
+ >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512
+ >>> q = torch.randn(B, T, H, K, device='cuda')
+ >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1)
+ >>> v = torch.randn(B, T, HV, V, device='cuda')
+ >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda'))
+ >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid()
+ >>> h0 = torch.randn(B, HV, K, V, device='cuda')
+ >>> o, ht = fused_gated_recurrent_delta_rule(
+ q, k, v, g, beta,
+ initial_state=h0,
+ )
+ # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required
+ >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta))
+ # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected
+ >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long)
+ >>> o_var, ht_var = fused_gated_recurrent_delta_rule(
+ q, k, v, g, beta,
+ initial_state=h0,
+ cu_seqlens=cu_seqlens
+ )
+ """
+ if cu_seqlens is not None and q.shape[0] != 1:
+ raise ValueError(
+ f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`."
+ f"Please flatten variable-length inputs before processing."
+ )
+ if scale is None:
+ scale = k.shape[-1] ** -0.5
+ else:
+ assert scale > 0, "scale must be positive"
+ fuse_gating = A_log is not None
+ if not fuse_gating and beta is None:
+ beta = torch.ones_like(q[..., 0])
+ o, final_state = FusedRecurrentFunction.apply(
+ q,
+ k,
+ v,
+ g,
+ beta,
+ scale,
+ initial_state,
+ inplace_final_state,
+ cu_seqlens,
+ ssm_state_indices,
+ ssm_state_write_indices,
+ num_accepted_tokens,
+ use_qk_l2norm_in_kernel,
+ A_log,
+ dt_bias,
+ a_raw,
+ b_raw,
+ out,
+ )
+ return o, final_state
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py
new file mode 100644
index 0000000000..8b1d59fc63
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py
@@ -0,0 +1,30 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+import torch
+
+import triton
+
+from .utils import tensor_cache
+
+
+@tensor_cache
+def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+ return cu_seqlens[1:] - cu_seqlens[:-1]
+
+
+@tensor_cache
+def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
+ indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()])
+ return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens)
+
+
+@tensor_cache
+def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor:
+ return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1)
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py
new file mode 100644
index 0000000000..29f892ef26
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py
@@ -0,0 +1,173 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import os
+
+import torch
+
+import triton
+import triton.language as tl
+from lightllm.common.triton_utils.autotuner import autotune
+
+# Candidate row-tile sizes (BT) for autotuning l2norm_fwd_kernel.
+BT_LIST = [8, 16, 32, 64, 128]
+
+# When set to a nonzero value, l2norm_fwd takes the original FLA kernel paths
+# (l2norm_fwd_kernel / l2norm_fwd_kernel1); otherwise the simpler row-block
+# kernel (l2norm_fwd_kernel2) is used.
+USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0"))
+
+
+@triton.jit
+def l2norm_fwd_kernel1(
+    x,  # input pointer: rows of D features
+    y,  # output pointer, same layout as x
+    D,  # feature dimension (runtime value)
+    BD: tl.constexpr,  # power-of-two block width covering D
+    eps,  # epsilon added under the sqrt for numerical stability
+):
+    # One program per row: scale row i_t of x to unit L2 norm.
+    i_t = tl.program_id(0)
+    x += i_t * D
+    y += i_t * D
+    # Sum of squares over the (masked) row
+    cols = tl.arange(0, BD)
+    mask = cols < D
+    b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32)
+    b_var = tl.sum(b_x * b_x, axis=0)
+    b_rstd = 1 / tl.sqrt(b_var + eps)
+    # Scale the row by its reciprocal L2 norm
+    b_y = b_x * b_rstd
+    tl.store(y + cols, b_y, mask=mask)
+
+
+@triton.jit(do_not_specialize=["NB"])
+def l2norm_fwd_kernel(
+    x,  # input pointer: T rows of D features
+    y,  # output pointer, same layout as x
+    eps,  # epsilon added under the sqrt for numerical stability
+    NB,  # NOTE(review): accepted (and non-specialized) but unused in the body
+    T,  # number of rows
+    D: tl.constexpr,  # feature dimension
+    BT: tl.constexpr,  # rows processed per program
+    BD: tl.constexpr,  # power-of-two block width covering D
+):
+    # Each program normalizes a (BT, BD) tile of rows to unit L2 norm.
+    i_t = tl.program_id(0)
+    p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32)
+    # Per-row sum of squares, then divide each row by sqrt(sum + eps).
+    b_var = tl.sum(b_x * b_x, axis=1)
+    b_y = b_x / tl.sqrt(b_var + eps)[:, None]
+    p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0))
+    tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.jit
+def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr):
+    # Normalize MBLOCK rows of length N per program; rows beyond M are masked.
+    xoffset = tl.program_id(0) * MBLOCK
+    row_idx = xoffset + tl.arange(0, MBLOCK)[:, None]
+    xmask = row_idx < M
+    rindex = tl.arange(0, N)[None, :]
+    xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32)
+    # Per-row sum of squares; masked (out-of-range) rows contribute zero.
+    square = tl.broadcast_to(xs * xs, [MBLOCK, N])
+    square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None]
+    rsqrt = tl.rsqrt(square_sum + eps)
+    tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask)
+
+
+def _get_l2norm_kernel1_configs():
+ return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]]
+
+
+def _get_l2norm_kernel1_static_key(x):
+ D = x.shape[-1]
+ return {"D": D}
+
+
+def _get_l2norm_kernel1_run_key(x):
+ return x.shape[0] # T
+
+
+@autotune(
+ kernel_name="l2norm_fwd_kernel1",
+ configs_gen_func=_get_l2norm_kernel1_configs,
+ static_key_func=_get_l2norm_kernel1_static_key,
+ run_key_func=_get_l2norm_kernel1_run_key,
+)
+def _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD, run_config=None):
+ if run_config is None:
+ run_config = {"num_warps": 4}
+
+ num_warps = run_config.get("num_warps", 4)
+ T = x.shape[0]
+
+ l2norm_fwd_kernel1[(T,)](x, y, eps=eps, D=D, BD=BD, num_warps=num_warps)
+
+
+def _get_l2norm_kernel_configs():
+ return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST]
+
+
+def _get_l2norm_kernel_static_key(x):
+ D = x.shape[-1]
+ return {"D": D}
+
+
+def _get_l2norm_kernel_run_key(x):
+ return x.shape[0] # T
+
+
+@autotune(
+ kernel_name="l2norm_fwd_kernel",
+ configs_gen_func=_get_l2norm_kernel_configs,
+ static_key_func=_get_l2norm_kernel_static_key,
+ run_key_func=_get_l2norm_kernel_run_key,
+)
+def _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB, run_config=None):
+ if run_config is None:
+ run_config = {"BT": 32, "num_warps": 4}
+
+ BT = run_config.get("BT", 32)
+ num_warps = run_config.get("num_warps", 4)
+
+ grid = (triton.cdiv(T, BT),)
+ l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BT=BT, BD=BD, num_warps=num_warps)
+
+
+def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None):
+ x_shape_og = x.shape
+ x = x.view(-1, x.shape[-1])
+ # allocate output
+ if output_dtype is None:
+ y = torch.empty_like(x)
+ else:
+ y = torch.empty_like(x, dtype=output_dtype)
+ assert y.stride(-1) == 1
+ T, D = x.shape[0], x.shape[-1]
+ # rstd = torch.empty((T,), dtype=torch.float32, device=x.device)
+ # Less than 64KB per feature: enqueue fused kernel
+ MAX_FUSED_SIZE = 65536 // x.element_size()
+ BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D))
+ if D > BD:
+ raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
+
+ if not USE_DEFAULT_FLA_NORM:
+ MBLOCK = 32
+ # M, N = x.shape
+ l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)](
+ x,
+ y,
+ eps,
+ T,
+ D,
+ MBLOCK,
+ )
+ else:
+ if D <= 512:
+ NB = triton.cdiv(T, 2048)
+ _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB)
+ else:
+ _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD)
+
+ return y.view(x_shape_og)
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py
new file mode 100644
index 0000000000..2f69aa981d
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py
@@ -0,0 +1,65 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+import triton
+import triton.language as tl
+
+from .utils import is_gather_supported
+
+# Short aliases so kernel code can call exp/log/log2 without the tl. prefix.
+exp = tl.exp
+log = tl.log
+log2 = tl.log2
+
+
+@triton.jit
+def safe_exp(x):
+ """
+ Numerically stable exponential function.
+ Only applies exp to non-positive values, returns 0 for positive values.
+ This prevents numerical overflow and improves stability.
+ """
+ return exp(tl.where(x <= 0, x, float("-inf")))
+
+
+# Resolve `gather` across Triton versions: use the native tl.gather when
+# available, otherwise install a no-op stub so kernels still compile.
+if not is_gather_supported:
+
+    @triton.jit
+    def gather(src, index, axis, _builder=None):
+        """
+        Gather operation that works when tl.gather is not supported.
+        This is a fallback implementation that returns None.
+        Just to make triton compiler happy.
+        """
+        return None
+
+else:
+    gather = tl.gather
+
+# Resolve the TMA tensor-descriptor constructor across Triton versions, with a
+# no-op stub when the API is absent entirely.
+if hasattr(triton.language, "_experimental_make_tensor_descriptor"):
+    # For Triton 3.3.x
+    make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor
+elif hasattr(triton.language, "make_tensor_descriptor"):
+    # For Triton 3.4.x and later
+    make_tensor_descriptor = triton.language.make_tensor_descriptor
+else:
+    """
+    Fallback implementation when TMA is not supported.
+    Returns None to indicate TMA descriptors are unavailable.
+    Just make triton compiler happy.
+    """
+
+    @triton.jit
+    def make_tensor_descriptor(
+        base,
+        shape,
+        strides,
+        block_shape,
+        _builder=None,
+    ):
+        return None
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py
new file mode 100644
index 0000000000..b5b6cfc369
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py
@@ -0,0 +1,462 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+
+import os
+from typing import Optional
+
+import torch
+
+import triton
+import triton.language as tl
+
+from .index import prepare_chunk_indices
+from .op import make_tensor_descriptor
+from .utils import input_guard, is_amd, is_tma_supported
+
+
+def _ensure_triton_allocator():
+ """Ensure Triton has an allocator set for kernels requiring scratch memory."""
+
+ def alloc_fn(size: int, alignment: int, stream: Optional[int]):
+ return torch.empty(size, device="cuda", dtype=torch.int8)
+
+ triton.set_allocator(alloc_fn)
+
+
+# Precision used for tl.dot inside the inversion kernels below; "tf32x3" is
+# not offered on AMD, hence the reduced allow-list there.
+FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee")
+ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"]
+assert (
+    FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS
+), f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}"
+
+
+# Invert (I + A) for each 16x16 diagonal block of A via forward substitution.
+# A is strictly lower triangular within the block; the result overwrites Ai.
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+@triton.jit(do_not_specialize=["T"])
+def solve_tril_16x16_kernel(
+    A,
+    Ai,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    BT: tl.constexpr,
+    USE_TMA: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    DOT_PRECISION: tl.constexpr,  # unused here (no tl.dot); kept for a uniform signature
+):
+    # Grid: (chunk index, batch * head). One program inverts one 16x16 block.
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        # Varlen mode: recover (sequence id, chunk id) and the sequence bounds.
+        i_n, i_t = (
+            tl.load(chunk_indices + i_t * 2).to(tl.int32),
+            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+        )
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    # m_A: strictly-lower mask; m_I: identity mask.
+    o_i = tl.arange(0, 16)
+    m_A = o_i[:, None] > o_i[None, :]
+    m_I = o_i[:, None] == o_i[None, :]
+
+    A = A + (bos * H + i_h) * BT
+    Ai = Ai + (bos * H + i_h) * 16
+
+    # Column offset of this 16-wide block within the BT-wide storage of A.
+    offset = (i_t * 16) % BT
+    if not USE_TMA:
+        p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0))
+        # [16, 16]
+        b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32)
+    else:
+        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
+        desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16])
+        b_A = desc.load([i_t * 16, offset]).to(tl.float32)
+    b_A = -tl.where(m_A, b_A, 0)
+
+    # Row-by-row forward substitution within the 16x16 tile.
+    for i in range(2, min(16, T - i_t * 16)):
+        # [16]
+        b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset)
+        b_a = b_a + tl.sum(b_a[:, None] * b_A, 0)
+        b_A = tl.where((o_i == i)[:, None], b_a, b_A)
+    # Add the identity to complete (I + A)^-1 for this block.
+    b_A += m_I
+    if not USE_TMA:
+        p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0))
+        tl.store(
+            p_Ai,
+            b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+    else:
+        desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+
+
+# Invert (I + A) for each 32x32 block by inverting its two diagonal 16x16
+# sub-blocks via substitution and composing the off-diagonal block as
+# Ai_21 = -Ai_22 @ A_21 @ Ai_11 (block-triangular inversion).
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+@triton.jit(do_not_specialize=["T"])
+def merge_16x16_to_32x32_inverse_kernel(
+    A,
+    Ai,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    BT: tl.constexpr,
+    USE_TMA: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    DOT_PRECISION: tl.constexpr,
+):
+    # Grid: (chunk index, batch * head). One program handles one 32x32 block.
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        # Varlen mode: recover (sequence id, chunk id) and the sequence bounds.
+        i_n, i_t = (
+            tl.load(chunk_indices + i_t * 2).to(tl.int32),
+            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+        )
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    # m_A: strictly-lower mask; m_I: identity mask.
+    o_i = tl.arange(0, 16)
+    m_A = o_i[:, None] > o_i[None, :]
+    m_I = o_i[:, None] == o_i[None, :]
+    A += (bos * H + i_h) * BT
+    Ai += (bos * H + i_h) * BT
+
+    if not USE_TMA:
+        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
+        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
+        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
+        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
+    else:
+        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
+        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
+        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
+        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
+
+    # [16, 16]
+    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
+    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
+
+    # Invert the two diagonal 16x16 blocks by forward substitution.
+    for i in range(2, min(16, T - i_t * BT)):
+        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
+        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
+        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
+    for i in range(16 + 2, min(32, T - i_t * BT)):
+        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
+        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
+        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
+
+    b_Ai_11 += m_I
+    b_Ai_22 += m_I
+
+    if not USE_TMA:
+        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
+        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
+    else:
+        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
+
+    # Off-diagonal block of the block-triangular inverse.
+    b_Ai_21 = -tl.dot(
+        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
+        b_Ai_11,
+        input_precision=DOT_PRECISION,
+    )
+
+    if not USE_TMA:
+        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
+        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
+        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
+        tl.store(
+            p_Ai_11,
+            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_22,
+            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_21,
+            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+    else:
+        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+
+
+# Invert (I + A) for each 64x64 block: the four diagonal 16x16 sub-blocks are
+# inverted by substitution, then the six lower off-diagonal blocks are built
+# from them (block-triangular inversion), e.g. Ai_21 = -Ai_22 @ A_21 @ Ai_11.
+@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
+@triton.jit(do_not_specialize=["T"])
+def merge_16x16_to_64x64_inverse_kernel(
+    A,
+    Ai,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    BT: tl.constexpr,
+    USE_TMA: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+    DOT_PRECISION: tl.constexpr,
+):
+    # Grid: (chunk index, batch * head). One program handles one 64x64 block.
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        # Varlen mode: recover (sequence id, chunk id) and the sequence bounds.
+        i_n, i_t = (
+            tl.load(chunk_indices + i_t * 2).to(tl.int32),
+            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
+        )
+        bos, eos = (
+            tl.load(cu_seqlens + i_n).to(tl.int32),
+            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
+        )
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    # m_A: strictly-lower mask; m_I: identity mask.
+    o_i = tl.arange(0, 16)
+    m_A = o_i[:, None] > o_i[None, :]
+    m_I = o_i[:, None] == o_i[None, :]
+    A += (bos * H + i_h) * BT
+    Ai += (bos * H + i_h) * BT
+
+    if not USE_TMA:
+        p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
+        p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
+        p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0))
+        p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0))
+        b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32)
+        b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32)
+        b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32)
+        b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32)
+    else:
+        desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16])
+        desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16])
+        b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32)
+        b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32)
+        b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32)
+        b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32)
+
+    # [16, 16]
+    b_Ai_11 = -tl.where(m_A, b_Ai_11, 0)
+    b_Ai_22 = -tl.where(m_A, b_Ai_22, 0)
+    b_Ai_33 = -tl.where(m_A, b_Ai_33, 0)
+    b_Ai_44 = -tl.where(m_A, b_Ai_44, 0)
+
+    # Invert the four diagonal 16x16 blocks by forward substitution.
+    for i in range(2, min(16, T - i_t * BT)):
+        b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i)
+        b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0)
+        b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11)
+    for i in range(16 + 2, min(32, T - i_t * BT)):
+        b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16)
+        b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0)
+        b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22)
+    for i in range(32 + 2, min(48, T - i_t * BT)):
+        b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32)
+        b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0)
+        b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33)
+    for i in range(48 + 2, min(64, T - i_t * BT)):
+        b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48)
+        b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0)
+        b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44)
+    b_Ai_11 += m_I
+    b_Ai_22 += m_I
+    b_Ai_33 += m_I
+    b_Ai_44 += m_I
+
+    # Load the six strictly-lower off-diagonal 16x16 blocks of A.
+    if not USE_TMA:
+        p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
+        p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0))
+        p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0))
+        p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0))
+        p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0))
+        p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0))
+        b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32)
+        b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32)
+        b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32)
+        b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32)
+        b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32)
+        b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32)
+    else:
+        b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32)
+        b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32)
+        b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32)
+        b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32)
+        b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32)
+        b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32)
+
+    # First sub-diagonal blocks of the inverse.
+    b_Ai_21 = -tl.dot(
+        tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION),
+        b_Ai_11,
+        input_precision=DOT_PRECISION,
+    )
+    b_Ai_32 = -tl.dot(
+        tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION),
+        b_Ai_22,
+        input_precision=DOT_PRECISION,
+    )
+    b_Ai_43 = -tl.dot(
+        tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION),
+        b_Ai_33,
+        input_precision=DOT_PRECISION,
+    )
+
+    # Remaining blocks reuse the previously computed sub-diagonal inverses.
+    b_Ai_31 = -tl.dot(
+        b_Ai_33,
+        tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION),
+        input_precision=DOT_PRECISION,
+    )
+    b_Ai_42 = -tl.dot(
+        b_Ai_44,
+        tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION),
+        input_precision=DOT_PRECISION,
+    )
+    b_Ai_41 = -tl.dot(
+        b_Ai_44,
+        tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION)
+        + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION)
+        + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION),
+        input_precision=DOT_PRECISION,
+    )
+
+    # Write back all ten lower-triangular 16x16 blocks of the inverse.
+    if not USE_TMA:
+        p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0))
+        p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0))
+        p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0))
+        p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0))
+        p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0))
+        p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0))
+        p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0))
+        p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0))
+        p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0))
+        p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0))
+        tl.store(
+            p_Ai_11,
+            b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_22,
+            b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_33,
+            b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_44,
+            b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_21,
+            b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_31,
+            b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_32,
+            b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_41,
+            b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_42,
+            b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+        tl.store(
+            p_Ai_43,
+            b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"),
+            boundary_check=(0, 1),
+        )
+    else:
+        desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+        desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne"))
+
+
+@input_guard
+def solve_tril(
+ A: torch.Tensor,
+ cu_seqlens: torch.Tensor | None = None,
+ output_dtype: torch.dtype = torch.float,
+) -> torch.Tensor:
+ """
+ Compute the inverse of the matrix I + A
+ A should be strictly lower triangular, i.e., A.triu() == 0.
+
+ Args:
+ A (torch.Tensor):
+ [B, T, H, BT], where BT should only be 16, 32, or 64.
+ cu_seqlens (torch.Tensor):
+ The cumulative sequence lengths of the input tensor. Default: `None`.
+ output_dtype (torch.dtype):
+ The dtype of the output tensor. Default: `torch.float`.
+ If `None`, the output dtype will be the same as the input dtype.
+
+ Returns:
+ (I + A)^-1 with the same shape as A
+ """
+ assert A.shape[-1] in [16, 32, 64]
+ output_dtype = A.dtype if output_dtype is None else output_dtype
+
+ B, T, H, BT = A.shape
+ chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
+ NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT)
+
+ Ai = torch.zeros_like(A, dtype=output_dtype)
+ if BT == 16:
+ merge_fn = solve_tril_16x16_kernel
+ elif BT == 32:
+ merge_fn = merge_16x16_to_32x32_inverse_kernel
+ elif BT == 64:
+ merge_fn = merge_16x16_to_64x64_inverse_kernel
+
+ # Ensure Triton allocator is set for TMA kernels that require scratch memory
+ if is_tma_supported:
+ _ensure_triton_allocator()
+
+ merge_fn[NT, B * H](
+ A=A,
+ Ai=Ai,
+ cu_seqlens=cu_seqlens,
+ chunk_indices=chunk_indices,
+ T=T,
+ H=H,
+ BT=BT,
+ USE_TMA=is_tma_supported,
+ DOT_PRECISION=FLA_TRIL_PRECISION,
+ )
+ return Ai
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py
new file mode 100644
index 0000000000..cd7c2e3aeb
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py
@@ -0,0 +1,179 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# ruff: noqa: E501
+import contextlib
+import functools
+import logging
+import os
+from collections.abc import Callable
+from enum import Enum
+from typing import Any, Literal
+
+import torch
+
+import triton
+
+logger = logging.getLogger(__name__)
+
+# Environment toggles read once at import time. Their consumers live outside
+# this module, so the precise effect of each flag cannot be confirmed here.
+COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1"
+FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1"
+FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1"
+
+# Recompute-suppression level for GDN; 0 disables suppression by default.
+SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0"))
+
+
+def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+ """
+ A decorator that caches the most recent results of a function with tensor inputs.
+
+ This decorator will store the output of the decorated function for the most recent set of input tensors.
+ The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed.
+
+ Args:
+ fn (Callable[..., torch.Tensor]):
+ The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+ Returns:
+ Callable[..., torch.Tensor]:
+ A wrapped version of the input function with single-entry caching.
+ """
+
+ cache_entries: tuple[tuple | None, dict | None, Any] = []
+ cache_size = 8
+
+ @functools.wraps(fn)
+ def wrapper(*args: Any, **kwargs: Any) -> Any:
+ nonlocal cache_entries
+ for i, entry in enumerate(cache_entries):
+ last_args, last_kwargs, last_result = entry
+ if (
+ len(args) == len(last_args)
+ and len(kwargs) == len(last_kwargs)
+ and all(a is b for a, b in zip(args, last_args))
+ and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items())
+ ):
+ cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)]
+ return last_result
+
+ result = fn(*args, **kwargs)
+
+ if len(cache_entries) >= cache_size:
+ cache_entries = cache_entries[1:]
+ cache_entries.append((args, kwargs, result))
+ return result
+
+ return wrapper
+
+
+def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]:
+ """
+ A decorator to make sure all input tensors are contiguous and set the device based on input tensors.
+ """
+
+ @functools.wraps(fn)
+ def wrapper(*args, **kwargs):
+ contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args)
+ contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()}
+
+ tensor = None
+ for arg in args:
+ if isinstance(arg, torch.Tensor):
+ tensor = arg
+ break
+ if tensor is None:
+ for value in kwargs.values():
+ if isinstance(value, torch.Tensor):
+ tensor = value
+ break
+
+ if tensor is not None:
+ ctx = torch.cuda.device(tensor.device.index)
+ else:
+ ctx = contextlib.nullcontext()
+
+ with ctx:
+ return fn(*contiguous_args, **contiguous_kwargs)
+
+ return wrapper
+
+
+@functools.cache
+def get_available_device() -> str:
+ try:
+ return triton.runtime.driver.active.get_current_target().backend
+ except BaseException:
+ return "cpu"
+
+
+@functools.cache
+def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]:
+ device = get_available_device()
+ mapping = {
+ "cuda": "nvidia",
+ "hip": "amd",
+ "xpu": "intel",
+ }
+ # return the mapped value, or the original if not found
+ return mapping.get(device, device)
+
+
+# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'.
+# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs.
+# Therefore, we need to check the triton backend to determine the actual GPU vendor.
+device = "cuda"
+device_torch_lib = getattr(torch, device, None)
+device_platform = _check_platform()
+
+is_amd = device_platform == "amd"
+is_intel = device_platform == "intel"
+is_nvidia = device_platform == "nvidia"
+# NOTE(review): the xpu name probe only runs when is_intel is True (short-circuit),
+# so this is safe on machines without an XPU runtime.
+is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0)
+# "Hopper-class": matched by device-name prefix or compute capability >= 9
+# (the capability test also matches newer architectures).
+is_nvidia_hopper = is_nvidia and (
+    "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9
+)
+use_cuda_graph = True
+# tl.gather only exists in recent Triton releases; probed at import time.
+is_gather_supported = hasattr(triton.language, "gather")
+# TMA descriptors require Hopper-class hardware AND a Triton exposing the API.
+is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and (
+    hasattr(triton.language, "_experimental_make_tensor_descriptor")
+    or hasattr(triton.language, "make_tensor_descriptor")
+)
+
+
+def get_all_max_shared_mem():
+ try:
+ return [
+ triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"]
+ for i in range(device_torch_lib.device_count())
+ ]
+ except BaseException:
+ return [-1]
+
+
+class Backend(Enum):
+ ADA = 101376 # RTX 4090
+ AMPERE = 166912 # A100
+ HOPPER = 232448 # H100
+ DEFAULT = 102400 # Default
+
+ @classmethod
+ def get_shared_memory(cls, arch: str) -> int:
+ try:
+ return cls[arch.upper()].value
+ except KeyError:
+ return cls.DEFAULT.value
+
+
+@functools.cache
+def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool:
+ try:
+ device_shared_mem_list = get_all_max_shared_mem()
+ max_shared_memory = device_shared_mem_list[tensor_idx]
+ return max_shared_memory >= Backend.get_shared_memory(arch)
+ except Exception:
+ return False
diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py
new file mode 100644
index 0000000000..08bb00e644
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py
@@ -0,0 +1,145 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang
+#
+# This file contains code copied from the flash-linear-attention project.
+# The original source code was licensed under the MIT license and included
+# the following copyright notice:
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+
+# ruff: noqa: E501
+
+import torch
+
+import triton
+import triton.language as tl
+
+from .index import prepare_chunk_indices
+
+
@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None})
@triton.jit(do_not_specialize=["T"])
def recompute_w_u_fwd_kernel(
    k,  # [B, T, Hg, K] keys; Hg KV-head groups shared by the H heads
    v,  # [B, T, H, V] values
    beta,  # [B, T, H] per-token gate
    w,  # out: [B, T, H, K] = A @ (k * beta * exp(g))
    u,  # out: [B, T, H, V] = A @ (v * beta)
    A,  # per-chunk mixing matrix, laid out [B*T, H, BT] (see p_A strides)
    g,  # [B, T, H] cumulative gates in log space (exponentiated below)
    cu_seqlens,  # varlen only: cumulative sequence lengths
    chunk_indices,  # varlen only: flat (seq_idx, chunk_idx) pairs per program
    T,
    H: tl.constexpr,
    Hg: tl.constexpr,
    K: tl.constexpr,
    V: tl.constexpr,
    BT: tl.constexpr,  # chunk (time-block) size
    BK: tl.constexpr,  # tile size along K
    BV: tl.constexpr,  # tile size along V
    IS_VARLEN: tl.constexpr,
):
    """Recompute the w/u chunk intermediates (flash-linear-attention).

    One program handles one (time-chunk, batch*head) pair.
    """
    i_t, i_bh = tl.program_id(0), tl.program_id(1)
    i_b, i_h = i_bh // H, i_bh % H
    if IS_VARLEN:
        # Remap the flat chunk id to (sequence, chunk-in-sequence), then switch
        # T to this sequence's length; bos/eos bound it in the packed token dim.
        i_n, i_t = (
            tl.load(chunk_indices + i_t * 2).to(tl.int32),
            tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32),
        )
        bos, eos = (
            tl.load(cu_seqlens + i_n).to(tl.int32),
            tl.load(cu_seqlens + i_n + 1).to(tl.int32),
        )
        T = eos - bos
    else:
        bos, eos = i_b * T, i_b * T + T
    p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,))
    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
    b_beta = tl.load(p_beta, boundary_check=(0,))
    b_A = tl.load(p_A, boundary_check=(0, 1))
    # Gates are stored in log space; exponentiate once for the whole chunk.
    b_g = tl.exp(tl.load(p_g, boundary_check=(0,)))

    # u = A @ (v * beta), tiled over the value dimension.
    for i_v in range(tl.cdiv(V, BV)):
        p_v = tl.make_block_ptr(
            v + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        p_u = tl.make_block_ptr(
            u + (bos * H + i_h) * V,
            (T, V),
            (H * V, 1),
            (i_t * BT, i_v * BV),
            (BT, BV),
            (1, 0),
        )
        b_v = tl.load(p_v, boundary_check=(0, 1))
        b_vb = (b_v * b_beta[:, None]).to(b_v.dtype)
        b_u = tl.dot(b_A, b_vb, allow_tf32=False)
        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))

    # w = A @ (k * beta * exp(g)), tiled over the key dimension.
    # i_h // (H // Hg) maps a query head to its shared KV-head group.
    for i_k in range(tl.cdiv(K, BK)):
        p_k = tl.make_block_ptr(
            k + (bos * Hg + i_h // (H // Hg)) * K,
            (T, K),
            (Hg * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        p_w = tl.make_block_ptr(
            w + (bos * H + i_h) * K,
            (T, K),
            (H * K, 1),
            (i_t * BT, i_k * BK),
            (BT, BK),
            (1, 0),
        )
        b_k = tl.load(p_k, boundary_check=(0, 1))
        b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype)
        b_w = tl.dot(b_A, b_kb)
        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
def recompute_w_u_fwd(
    k: torch.Tensor,
    v: torch.Tensor,
    beta: torch.Tensor,
    g_cumsum: torch.Tensor,
    A: torch.Tensor,
    cu_seqlens: torch.LongTensor | None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Launch wrapper for `recompute_w_u_fwd_kernel`.

    Args:
        k: [B, T, Hg, K] keys (Hg = number of KV-head groups).
        v: [..., H, V] values; assumed [B, T, H, V] — H and V are read from
           its last two dims.
        beta: per-token gate; assumed [B, T, H] (indexed with stride H in-kernel).
        g_cumsum: cumulative log-gates, same layout as beta (exp'd in-kernel).
        A: per-chunk mixing matrix; its last dim fixes the chunk size BT.
        cu_seqlens: optional cumulative sequence lengths for varlen batches
            (chunks are then enumerated via `prepare_chunk_indices`).

    Returns:
        (w, u): w is [B, T, H, K]; u matches v's shape.
    """
    B, T, Hg, K, V = *k.shape, v.shape[-1]
    H = v.shape[-2]
    BT = A.shape[-1]

    # Varlen mode: grid axis 0 walks (sequence, chunk) pairs instead of plain chunks.
    chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None
    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
    BK = 64
    BV = 64
    u = torch.empty_like(v)
    w = k.new_empty(B, T, H, K)
    recompute_w_u_fwd_kernel[(NT, B * H)](
        k=k,
        v=v,
        beta=beta,
        w=w,
        u=u,
        A=A,
        g=g_cumsum,
        cu_seqlens=cu_seqlens,
        chunk_indices=chunk_indices,
        T=T,
        H=H,
        Hg=Hg,
        K=K,
        V=V,
        BT=BT,
        BK=BK,
        BV=BV,
    )
    return w, u
diff --git a/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py
new file mode 100644
index 0000000000..6413158a66
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fused_add_gemma_rmsnorm.py
@@ -0,0 +1,186 @@
+import torch
+
+import triton
+import triton.language as tl
+
+from lightllm.common.triton_utils.autotuner import autotune
+
+
@triton.jit
def _fused_add_gemma_rmsnorm_kernel(
    x_ptr,  # [M, N] input; overwritten in-place with x + residual
    r_ptr,  # [M, N] residual
    w_ptr,  # [N] Gemma norm weight (applied as w + 1.0)
    y_ptr,  # [M, N] normalized output
    x_stride0,
    x_stride1,
    r_stride0,
    r_stride1,
    y_stride0,
    y_stride1,
    N: tl.constexpr,
    EPS: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Fused in-place residual add + Gemma RMSNorm.

    For each row:
    1. sum = x + residual (written back to x in-place)
    2. rstd = 1 / sqrt(mean(sum²) + eps)
    3. y = sum * rstd * (w + 1.0) (Gemma-style)

    One program owns one full row, so the in-place write to x is race-free.
    """
    row = tl.program_id(0)
    x_ptr = x_ptr + row * x_stride0
    r_ptr = r_ptr + row * r_stride0
    y_ptr = y_ptr + row * y_stride0

    # Pass 1: compute sum = x + residual, write back to x, accumulate sum² for variance
    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
        r = tl.load(r_ptr + cols * r_stride1, mask=mask, other=0.0).to(tl.float32)
        s = x + r
        # Write sum back to x (in-place residual add)
        tl.store(x_ptr + cols * x_stride1, s.to(x_ptr.dtype.element_ty), mask=mask)
        _var += s * s

    # Masked lanes contributed 0.0 to _var, so the reduction over BLOCK_SIZE is exact.
    var = tl.sum(_var, axis=0) / N
    rstd = 1.0 / tl.sqrt(var + EPS)

    # Pass 2: normalize and apply Gemma-style linear transformation
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        # Re-read x (now contains sum); hot in L2 from the write in pass 1
        s = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32)
        # No `other` for w: masked lanes stay undefined but the store below is masked too.
        w = tl.load(w_ptr + cols, mask=mask).to(tl.float32)
        y = s * rstd * (w + 1.0)
        tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask)
+
+
+def _get_fused_add_gemma_rmsnorm_configs():
+ """Generate configurations for autotuning fused add + Gemma RMSNorm kernel."""
+ configs = []
+ for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]:
+ for num_warps in [1, 2, 4, 8]:
+ configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": 1})
+ return configs
+
+
+def _get_fused_add_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor):
+ """Generate static key for caching autotuned configurations."""
+ N = x.shape[-1]
+ return {
+ "x_dtype": str(x.dtype),
+ "weight_dtype": str(w.dtype),
+ "N": N,
+ }
+
+
@autotune(
    kernel_name="fused_add_gemma_rmsnorm:v1",
    configs_gen_func=_get_fused_add_gemma_rmsnorm_configs,
    static_key_func=_get_fused_add_gemma_rmsnorm_static_key,
    run_key_func=lambda x: x.shape[-1],
    mutates_args=["x"],
)
def fused_add_gemma_rmsnorm(x, residual, w, eps, out=None, run_config: dict = None):
    """Fused in-place residual add + Gemma RMSNorm.

    Args:
        x: [M, N] - modified in-place (x += residual)
        residual: [M, N] - residual to add (will be viewed as [-1, N])
        w: [N] - norm weight (Gemma-style: applies w + 1.0)
        eps: variance epsilon
        out: [M, N] - output buffer (allocated if None)
        run_config: autotune-provided {"BLOCK_SIZE", "num_warps", "num_stages"}

    Returns:
        out (normalized result); x additionally holds x + residual.
    """
    N = x.shape[-1]
    y = torch.empty_like(x) if out is None else out
    x_arg = x.view(-1, N)
    r_arg = residual.view(-1, N)
    y_arg = y.view(-1, N)

    M = x_arg.shape[0]

    # Default heuristic when autotune is disabled or no config provided
    if not run_config:
        # Cap the block at 64KB worth of elements. The kernel iterates over N
        # in BLOCK_SIZE-sized chunks (both passes), so a BLOCK_SIZE smaller
        # than N is still correct — clamp instead of rejecting large N.
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
        num_warps = min(max(BLOCK_SIZE // 256, 1), 8)
        run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1}

    BLOCK_SIZE = run_config["BLOCK_SIZE"]
    num_warps = run_config["num_warps"]
    num_stages = run_config["num_stages"]

    # One program per row: each does the add, the variance reduction and the
    # normalization for its row.
    _fused_add_gemma_rmsnorm_kernel[(M,)](
        x_arg,
        r_arg,
        w,
        y_arg,
        x_stride0=x_arg.stride(0),
        x_stride1=x_arg.stride(1),
        r_stride0=r_arg.stride(0),
        r_stride1=r_arg.stride(1),
        y_stride0=y_arg.stride(0),
        y_stride1=y_arg.stride(1),
        N=N,
        EPS=eps,
        BLOCK_SIZE=BLOCK_SIZE,
        num_warps=num_warps,
        num_stages=num_stages,
    )

    return y
+
+
+def _fused_add_gemma_rmsnorm_torch(x, residual, weight, eps):
+ """Reference implementation for correctness testing."""
+ original_dtype = x.dtype
+ x = x.to(torch.float32)
+ residual = residual.to(torch.float32)
+ s = x + residual
+ normed = s * torch.rsqrt(s.pow(2).mean(-1, keepdim=True) + eps)
+ out = normed * (1.0 + weight.float())
+ return s.to(original_dtype), out.to(original_dtype)
+
+
def test_fused_add_gemma_rmsnorm(M=128, N=2048, dtype=torch.bfloat16, eps=1e-5, device="cuda"):
    """Verify fused kernel matches separate add + gemma_rmsnorm.

    Checks both the in-place residual add (x <- x + residual) and the
    normalized output against the pure-PyTorch reference. Requires a CUDA
    device; tolerances are loose to accommodate bf16/fp16 rounding.
    """
    x_shape = (M, N)
    w_shape = (N,)
    weight = torch.rand(w_shape, dtype=dtype, device=device)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device=device)
    residual = 0.1 * torch.randn(x_shape, dtype=dtype, device=device)

    # Clone x for reference (since fused modifies x in-place)
    x_ref = x.clone()
    x_fused = x.clone()

    # Reference: separate add + norm
    x_ref_sum, y_ref = _fused_add_gemma_rmsnorm_torch(x_ref, residual, weight, eps)

    # Fused kernel
    y_fused = fused_add_gemma_rmsnorm(x_fused, residual, weight, eps)

    # Check x was modified in-place (x += residual)
    print(f"Test: M={M}, N={N}, dtype={dtype}")
    print(f"  x in-place max delta: {torch.max(torch.abs(x_fused - x_ref_sum)):.6e}")
    print(f"  output max delta: {torch.max(torch.abs(y_fused - y_ref)):.6e}")

    atol = 1e-2 if dtype == torch.float32 else 5e-2
    assert torch.allclose(x_fused, x_ref_sum, atol=atol, rtol=0), "x in-place update mismatch!"
    assert torch.allclose(y_fused, y_ref, atol=atol, rtol=0), "output mismatch!"
    print("  PASSED")
+
+
if __name__ == "__main__":
    # Smoke-test the fused kernel against the PyTorch reference across
    # batch sizes and dtypes (requires a CUDA device).
    test_fused_add_gemma_rmsnorm(M=1, N=2048)
    test_fused_add_gemma_rmsnorm(M=128, N=2048)
    test_fused_add_gemma_rmsnorm(M=1, N=2048, dtype=torch.float16)
    test_fused_add_gemma_rmsnorm(M=64, N=4096, dtype=torch.float32)
    print("All tests passed!")
diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py
new file mode 100644
index 0000000000..88febaffc6
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py
@@ -0,0 +1,93 @@
+# Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py
+from typing import Optional, Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from lightllm.common.triton_utils.autotuner import autotune
+
+
+# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias)
+# beta_output = b.sigmoid()
+@triton.jit
+def fused_gdn_gating_kernel(
+ g,
+ beta_output,
+ A_log,
+ a,
+ b,
+ dt_bias,
+ stride_a_row,
+ stride_b_row,
+ NUM_HEADS: tl.constexpr,
+ beta: tl.constexpr,
+ threshold: tl.constexpr,
+ BLK_HEADS: tl.constexpr,
+):
+ i_b, i_d = tl.program_id(0), tl.program_id(1)
+ head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS)
+ off = i_b * NUM_HEADS + head_off
+ off_a = i_b * stride_a_row + head_off
+ off_b = i_b * stride_b_row + head_off
+ mask = head_off < NUM_HEADS
+ blk_A_log = tl.load(A_log + head_off, mask=mask)
+ blk_a = tl.load(a + off_a, mask=mask)
+ blk_b = tl.load(b + off_b, mask=mask)
+ blk_bias = tl.load(dt_bias + head_off, mask=mask)
+ x = blk_a.to(tl.float32) + blk_bias.to(tl.float32)
+ softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x)
+ blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x
+ tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask)
+ blk_beta_output = tl.sigmoid(blk_b.to(tl.float32))
+ tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask)
+
+
+def _get_fused_gdn_gating_configs():
+ return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [4, 8, 16, 32, 64] for nw in [1, 2, 4]]
+
+
+def _get_fused_gdn_gating_static_key(a: torch.Tensor):
+ # group by head size and input dtype
+ return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)}
+
+
@autotune(
    kernel_name="fused_gdn_gating:v1",
    configs_gen_func=_get_fused_gdn_gating_configs,
    static_key_func=_get_fused_gdn_gating_static_key,
    run_key_func=lambda a: a.shape[0],
)
def fused_gdn_gating(
    A_log: torch.Tensor,
    a: torch.Tensor,
    b: torch.Tensor,
    dt_bias: torch.Tensor,
    beta: float = 1.0,
    threshold: float = 20.0,
    run_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compute GDN gating in one launch.

    g = -exp(A_log) * softplus(a + dt_bias)  and  beta_output = sigmoid(b).

    Args:
        A_log: [num_heads] log of the per-head decay parameter.
        a, b: [batch, num_heads] projections; only row strides are passed to
            the kernel, so rows need not be contiguous.
        dt_bias: [num_heads] bias added to `a` before softplus.
        beta / threshold: softplus parameters (threshold = linear-regime cutoff).
        run_config: autotune-provided {"BLK_HEADS", "num_warps"}.

    Returns:
        (g, beta_output), both [batch, num_heads] float32.
    """

    if run_config is None:
        run_config = {"BLK_HEADS": 8, "num_warps": 1}

    batch, num_heads = a.shape
    # Grid: one program per (batch row, head-block).
    grid = (batch, triton.cdiv(num_heads, run_config["BLK_HEADS"]))
    g = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device)
    beta_output = torch.empty(batch, num_heads, dtype=torch.float32, device=a.device)
    fused_gdn_gating_kernel[grid](
        g,
        beta_output,
        A_log,
        a,
        b,
        dt_bias,
        a.stride(0),
        b.stride(0),
        num_heads,
        beta,
        threshold,
        run_config["BLK_HEADS"],
        num_warps=run_config["num_warps"],
    )
    return g, beta_output
diff --git a/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py
new file mode 100644
index 0000000000..f37d4911af
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/fused_qkv_gating.py
@@ -0,0 +1,163 @@
+"""
+Fused QKV projection and GDN gating computation.
+
+This kernel fuses:
+1. Linear projection (matmul with weight)
+2. Output reorganization (split and reshape)
+3. Gating computation (g and beta from a, b)
+
+This reduces kernel launches from 3 to 1 for the QKV+gating path.
+"""
+
+import torch
+import triton
+import triton.language as tl
+from typing import Tuple, Optional
+from lightllm.common.triton_utils.autotuner import autotune
+
+
@triton.jit
def _fused_gdn_gating_only_kernel(
    # Output pointers
    g_ptr,  # [batch, num_heads] gate output
    beta_ptr,  # [batch, num_heads] sigmoid(b) output
    # Input pointers
    a_ptr,  # [batch, num_heads], assumed row-contiguous (no stride args)
    b_ptr,  # [batch, num_heads], assumed row-contiguous
    A_log_ptr,  # [num_heads]
    dt_bias_ptr,  # [num_heads]
    # Dimensions
    batch_size,
    num_heads,
    # Constants
    beta_const: tl.constexpr,  # softplus beta
    threshold: tl.constexpr,  # linear-regime cutoff for the stable softplus
    BLOCK_BATCH: tl.constexpr,
    BLOCK_HEADS: tl.constexpr,
):
    """
    Fused kernel for GDN gating computation with better memory access patterns.

    Computes:
    - g = -exp(A_log) * softplus(a + dt_bias)
    - beta = sigmoid(b)

    Each program handles a (BLOCK_BATCH x BLOCK_HEADS) tile.
    """
    pid_batch = tl.program_id(0)
    pid_head = tl.program_id(1)

    batch_offs = pid_batch * BLOCK_BATCH + tl.arange(0, BLOCK_BATCH)
    head_offs = pid_head * BLOCK_HEADS + tl.arange(0, BLOCK_HEADS)

    batch_mask = batch_offs < batch_size
    head_mask = head_offs < num_heads
    mask = batch_mask[:, None] & head_mask[None, :]

    # Load A_log and dt_bias (broadcast across batch)
    A_log = tl.load(A_log_ptr + head_offs, mask=head_mask, other=0.0)
    dt_bias = tl.load(dt_bias_ptr + head_offs, mask=head_mask, other=0.0)

    # Load a and b
    offs = batch_offs[:, None] * num_heads + head_offs[None, :]
    a = tl.load(a_ptr + offs, mask=mask, other=0.0)
    b = tl.load(b_ptr + offs, mask=mask, other=0.0)

    # Compute g = -exp(A_log) * softplus(a + dt_bias)
    x = a.to(tl.float32) + dt_bias.to(tl.float32)
    # Numerically stable softplus: identity once beta_const*x exceeds the threshold.
    softplus_x = tl.where(beta_const * x <= threshold, (1.0 / beta_const) * tl.log(1.0 + tl.exp(beta_const * x)), x)
    g = -tl.exp(A_log.to(tl.float32)) * softplus_x

    # Compute beta = sigmoid(b)
    beta_out = tl.sigmoid(b.to(tl.float32))

    # Store outputs with layout [1, batch, num_heads]
    # (same offsets as the loads: the leading singleton dim is free for a
    # contiguous [batch, num_heads] buffer)
    out_offs = batch_offs[:, None] * num_heads + head_offs[None, :]
    tl.store(g_ptr + out_offs, g.to(g_ptr.dtype.element_ty), mask=mask)
    tl.store(beta_ptr + out_offs, beta_out.to(beta_ptr.dtype.element_ty), mask=mask)
+
+
+def _get_fused_gating_configs():
+ """Generate autotuning configurations."""
+ configs = []
+ for block_batch in [1, 4, 8, 16]:
+ for block_heads in [8, 16, 32]:
+ for num_warps in [2, 4, 8]:
+ configs.append(
+ {
+ "BLOCK_BATCH": block_batch,
+ "BLOCK_HEADS": block_heads,
+ "num_warps": num_warps,
+ }
+ )
+ return configs
+
+
+def _get_fused_gating_static_key(a: torch.Tensor):
+ return {"dtype": str(a.dtype), "num_heads": a.shape[1]}
+
+
+def _get_fused_gating_run_key(a: torch.Tensor):
+ return a.shape[0]
+
+
@autotune(
    kernel_name="fused_gdn_gating_v2:v1",
    configs_gen_func=_get_fused_gating_configs,
    static_key_func=_get_fused_gating_static_key,
    run_key_func=_get_fused_gating_run_key,
    # Outputs are written in place into caller-provided buffers.
    mutates_args=["g", "beta"],
)
def fused_gdn_gating_v2(
    a: torch.Tensor,
    b: torch.Tensor,
    A_log: torch.Tensor,
    dt_bias: torch.Tensor,
    g: torch.Tensor,
    beta: torch.Tensor,
    beta_const: float = 1.0,
    threshold: float = 20.0,
    run_config: Optional[dict] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Optimized GDN gating with pre-allocated output tensors.

    Args:
        a: Input tensor [batch, num_heads]
        b: Input tensor [batch, num_heads]
        A_log: Log of A parameter [num_heads]
        dt_bias: Bias for dt [num_heads]
        g: Output tensor [1, batch, num_heads] (pre-allocated)
        beta: Output tensor [1, batch, num_heads] (pre-allocated)
        beta_const: Beta constant for softplus (default: 1.0)
        threshold: Threshold for softplus approximation (default: 20.0)
        run_config: Optional autotuning configuration

    Returns:
        Tuple of (g, beta) - same tensors passed in, now filled
    """
    batch_size, num_heads = a.shape

    if run_config is None:
        run_config = {"BLOCK_BATCH": 8, "BLOCK_HEADS": 16, "num_warps": 4}

    # 2D grid of (batch tiles, head tiles); the kernel masks the tile edges.
    grid = (
        triton.cdiv(batch_size, run_config["BLOCK_BATCH"]),
        triton.cdiv(num_heads, run_config["BLOCK_HEADS"]),
    )

    _fused_gdn_gating_only_kernel[grid](
        g,
        beta,
        a,
        b,
        A_log,
        dt_bias,
        batch_size,
        num_heads,
        beta_const,
        threshold,
        run_config["BLOCK_BATCH"],
        run_config["BLOCK_HEADS"],
        num_warps=run_config["num_warps"],
    )

    return g, beta
diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py
new file mode 100644
index 0000000000..89db5e00cb
--- /dev/null
+++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py
@@ -0,0 +1,174 @@
+import triton
+import triton.language as tl
+import torch
+from lightllm.common.triton_utils.autotuner import autotune
+
+
@triton.heuristics(
    {
        "HAS_BIAS": lambda args: args["B"] is not None,
    }
)
@triton.jit
def gated_rmsnorm_forward_kernel(
    X,  # pointer to the input
    Y,  # pointer to the output
    W,  # pointer to the weights
    B,  # pointer to the biases
    Z,  # pointer to the other branch (required, not optional)
    Rstd,  # pointer to the 1/std
    stride_x_row,  # how much to increase the pointer when moving by 1 row
    stride_y_row,
    stride_z_row,
    M,  # number of rows in X
    N,  # number of columns in X (per group)
    eps,  # epsilon to avoid division by zero
    BLOCK_N: tl.constexpr,
    HAS_BIAS: tl.constexpr,
    NORM_BEFORE_GATE: tl.constexpr,
):
    """One program normalizes one (row, group) slice of N features.

    The gate branch z is applied as SiLU (z * sigmoid(z)) — to the input
    when NORM_BEFORE_GATE is False, to the affine output when True.
    """
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    group = tl.program_id(1)
    X += row * stride_x_row + group * N
    Y += row * stride_y_row + group * N
    Z += row * stride_z_row + group * N
    Rstd += group * M  # Rstd buffer is laid out [ngroups, M]
    W += group * N
    if HAS_BIAS:
        B += group * N
    # Compute variance (RMS norm doesn't use mean)
    cols = tl.arange(0, BLOCK_N)
    x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
    if not NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=cols < N).to(tl.float32)
        x *= z * tl.sigmoid(z)
    # RMS norm: compute variance directly without mean subtraction.
    # The where() re-zeroes masked lanes, which may hold garbage after the
    # gate multiply above (z is loaded without an `other` fill value).
    xbar = tl.where(cols < N, x, 0.0)
    var = tl.sum(xbar * xbar, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)
    tl.store(Rstd + row, rstd)
    # Normalize and apply linear transformation
    mask = cols < N
    w = tl.load(W + cols, mask=mask).to(tl.float32)
    if HAS_BIAS:
        b = tl.load(B + cols, mask=mask).to(tl.float32)
    # RMS norm: normalize without mean subtraction
    x_hat = x * rstd
    y = x_hat * w + b if HAS_BIAS else x_hat * w
    if NORM_BEFORE_GATE:
        z = tl.load(Z + cols, mask=mask).to(tl.float32)
        y *= z * tl.sigmoid(z)
    # Write output
    tl.store(Y + cols, y, mask=mask)
+
+
+def _get_gated_rmsnorm_configs():
+ """Generate configurations for autotuning gated RMSNorm kernel."""
+ configs = []
+ # Different BLOCK_N sizes (powers of 2)
+ for block_n in [64, 128, 256, 512, 1024, 2048, 4096]:
+ # Different number of warps
+ for num_warps in [1, 2, 4, 8]:
+ # Skip configurations that are likely to be inefficient
+ if block_n >= 2048 and num_warps > 4:
+ continue
+ if block_n <= 128 and num_warps > 2:
+ continue
+ configs.append({"BLOCK_N": block_n, "num_warps": num_warps})
+ return configs
+
+
+def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
+ """Generate static key for caching autotuned configurations."""
+ M, N = x.shape
+ return {
+ "x_dtype": str(x.dtype),
+ "weight_dtype": str(weight.dtype),
+ "N": N,
+ "has_bias": bias is not None,
+ }
+
+
@autotune(
    kernel_name="gated_rmsnorm_forward:v1",
    configs_gen_func=_get_gated_rmsnorm_configs,
    static_key_func=_get_gated_rmsnorm_static_key,
    run_key_func=lambda x: x.shape[0],
)
def gated_rmsnorm_forward(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    eps: float,
    z: torch.Tensor,
    out: torch.Tensor = None,
    group_size: int = None,
    norm_before_gate: bool = True,
    run_config: dict = None,
):
    """Gated (SiLU) RMSNorm forward.

    Args:
        x: [M, N] input; last dim must be contiguous.
        weight: [N] scale; bias: optional [N] shift.
        eps: variance epsilon.
        z: [M, N] gate branch — output is multiplied by z * sigmoid(z),
           before normalization when norm_before_gate=False, after the
           affine transform when True.
        out: optional pre-allocated output buffer (same shape as x).
        group_size: features per normalization group (defaults to N,
           i.e. a single group).
        run_config: autotune-provided {"BLOCK_N", "num_warps"}.

    Returns:
        The normalized (and gated) tensor, shape of x.
    """
    M, N = x.shape
    if group_size is None:
        group_size = N
    assert N % group_size == 0
    ngroups = N // group_size
    assert x.stride(-1) == 1
    # z is required for gated_rmsnorm
    assert z is not None, "z cannot be None for gated_rmsnorm_forward"
    assert z.stride(-1) == 1
    assert z.shape == (M, N)
    assert weight.shape == (N,)
    assert weight.stride(-1) == 1
    if bias is not None:
        assert bias.stride(-1) == 1
        assert bias.shape == (N,)
    # allocate output
    if out is not None:
        assert out.shape == x.shape
    else:
        out = torch.empty_like(x)
    assert out.stride(-1) == 1
    # For RMS norm, we still need rstd for the kernel
    # (laid out [ngroups, M]; written by the kernel, not returned here).
    rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device)

    # Default heuristic when autotune is disabled or no config provided
    if not run_config:
        # Less than 64KB per feature: enqueue fused kernel
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
        if group_size > BLOCK_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
        # heuristics for number of warps
        num_warps = min(max(BLOCK_N // 256, 1), 8)
        run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps}

    BLOCK_N = run_config["BLOCK_N"]
    num_warps = run_config["num_warps"]

    # Validate BLOCK_N against group_size: the kernel covers the whole group
    # with one tl.arange(0, BLOCK_N), so BLOCK_N must be >= group_size.
    if group_size > BLOCK_N:
        # Fall back to largest valid BLOCK_N
        # (note: num_warps from the tuned config is kept as-is)
        MAX_FUSED_SIZE = 65536 // x.element_size()
        BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size))
        if group_size > BLOCK_N:
            raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")

    # One program per (row, group).
    grid = (M, ngroups)
    gated_rmsnorm_forward_kernel[grid](
        x,
        out,
        weight,
        bias,
        z,
        rstd,
        x.stride(0),
        out.stride(0),
        z.stride(0),
        M,
        group_size,
        eps,
        BLOCK_N=BLOCK_N,
        NORM_BEFORE_GATE=norm_before_gate,
        num_warps=num_warps,
    )
    return out
diff --git a/lightllm/models/vit/triton_kernel/flashattention_nopad.py b/lightllm/models/vit/triton_kernel/flashattention_nopad.py
index 8428e52996..b43f8f95af 100644
--- a/lightllm/models/vit/triton_kernel/flashattention_nopad.py
+++ b/lightllm/models/vit/triton_kernel/flashattention_nopad.py
@@ -195,7 +195,8 @@ def flash_attention_v3_fwd(
False,
window_size[0],
window_size[1],
- 0.0,
+ 0, # attention_chunk
+ 0.0, # softcap
is_rotary_interleaved=False,
scheduler_metadata=None,
num_splits=1,
diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index c8a82d3239..8ff03f3e29 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -128,7 +128,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
parser.add_argument(
"--tool_call_parser",
type=str,
- choices=["qwen25", "llama3", "mistral", "deepseekv3", "qwen", "deepseekv31", "deepseekv32", "glm47", "kimi_k2"],
+ choices=[
+ "qwen25",
+ "llama3",
+ "mistral",
+ "deepseekv3",
+ "qwen",
+ "deepseekv31",
+ "deepseekv32",
+ "glm47",
+ "kimi_k2",
+ "qwen3_coder",
+ ],
default=None,
help="tool call parser type",
)
@@ -167,7 +178,7 @@ def make_argument_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
- "--running_max_req_size", type=int, default=1000, help="the max size for forward requests in the same time"
+ "--running_max_req_size", type=int, default=256, help="the max size for forward requests in the same time"
)
parser.add_argument("--nnodes", type=int, default=1, help="the number of nodes")
parser.add_argument("--node_rank", type=int, default=0, help="the rank of the current node")
@@ -568,7 +579,13 @@ def make_argument_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--mtp_mode",
- choices=["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att", None],
+ choices=[
+ "vanilla_with_att",
+ "eagle_with_att",
+ "vanilla_no_att",
+ "eagle_no_att",
+ None,
+ ],
default=None,
help="""Supported MTP modes.
None: Disables MTP.
@@ -638,6 +655,33 @@ def make_argument_parser() -> argparse.ArgumentParser:
default=False,
help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""",
)
+ parser.add_argument(
+ "--mamba_cache_size",
+ type=int,
+ default=None,
+ help="""The size of linear attn cache. If not specified, will be calculated
+ automatically based on mamba_cache_ratio or max_total_token_num.""",
+ )
+ parser.add_argument(
+ "--mamba_cache_ratio",
+ type=lambda v: float(v)
+ if 0.0 <= (_ := float(v)) <= 1.0
+ else (_ for _ in ()).throw(
+ argparse.ArgumentTypeError(f"--mamba_cache_ratio must be between 0.0 and 1.0, got {v}")
+ ),
+ default=0.5,
+ help="""Ratio of mamba cache to total cache memory (mamba + KV).
+ Only effective when both mamba_cache_size and max_total_token_num are not set.
+ Default is 0.5 (50%% mamba cache, 50%% KV cache).
+ Example: 0.3 -> 30%% mamba, 70%% KV; 0.7 -> 70%% mamba, 30%% KV.""",
+ )
+ parser.add_argument(
+ "--mamba_ssm_data_type",
+ type=str,
+ choices=["bfloat16", "float32"],
+ default="float32",
+ help="the data type of the model weight",
+ )
parser.add_argument(
"--hardware_platform",
type=str,
diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py
index 598bd8f1f2..c4cdfa4ba3 100644
--- a/lightllm/server/api_openai.py
+++ b/lightllm/server/api_openai.py
@@ -193,6 +193,13 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req
multimodal_params_dict["images"].append({"type": "base64", "data": data})
else:
raise ValueError("Unrecognized image input.")
+ elif img.startswith("file://"):
+ # Local file path with file:// prefix
+ file_path = img[7:] # Remove "file://" prefix
+ with open(file_path, "rb") as f:
+ multimodal_params_dict["images"].append(
+ {"type": "base64", "data": base64.b64encode(f.read()).decode("utf-8")}
+ )
else:
raise ValueError(
"Unrecognized image input. Supports local path, http url, base64, and PIL.Image."
diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py
index 7f16d519a2..a38008af6f 100644
--- a/lightllm/server/build_prompt.py
+++ b/lightllm/server/build_prompt.py
@@ -46,6 +46,7 @@ async def build_prompt(request, tools) -> str:
global tokenizer
# pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别
messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages]
+
kwargs = {"conversation": messages}
if request.character_settings:
kwargs["character_settings"] = request.character_settings
@@ -57,15 +58,7 @@ async def build_prompt(request, tools) -> str:
try:
input_str = tokenizer.apply_chat_template(**kwargs, tokenize=False, add_generation_prompt=True, tools=tools)
- except:
- # This except branch will be triggered when the chosen model
- # has a different tools input format that is not compatiable
- # with openAI's apply_chat_template tool_call format, like Mistral.
- tools = [t if "function" in t else {"function": t} for t in tools]
- input_str = tokenizer.apply_chat_template(
- **kwargs,
- tokenize=True,
- add_generation_prompt=True,
- tools=tools,
- )
+ except BaseException as e:
+ logger.error(f"Failed to build prompt: {e}")
+ raise e
return input_str
diff --git a/lightllm/server/core/objs/sampling_params.py b/lightllm/server/core/objs/sampling_params.py
index 49b21c38fc..aa71558146 100644
--- a/lightllm/server/core/objs/sampling_params.py
+++ b/lightllm/server/core/objs/sampling_params.py
@@ -334,15 +334,31 @@ class SamplingParams(ctypes.Structure):
def init(self, tokenizer, **kwargs):
super().__init__()
+ # 移除kwargs中为null的参数,避免覆盖默认值
+ kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
self.best_of = kwargs.get("best_of", 1)
self.n = kwargs.get("n", self.best_of)
- self.do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
- self.presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
- self.frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
- self.repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
- self.temperature = kwargs.get("temperature", SamplingParams._temperature)
- self.top_p = kwargs.get("top_p", SamplingParams._top_p)
- self.top_k = kwargs.get("top_k", SamplingParams._top_k)
+ do_sample = kwargs.get("do_sample", SamplingParams._do_sample)
+ self.do_sample = False if do_sample is None else do_sample
+
+ presence_penalty = kwargs.get("presence_penalty", SamplingParams._presence_penalty)
+ self.presence_penalty = 0.0 if presence_penalty is None else presence_penalty
+
+ frequency_penalty = kwargs.get("frequency_penalty", SamplingParams._frequency_penalty)
+ self.frequency_penalty = 0.0 if frequency_penalty is None else frequency_penalty
+
+ repetition_penalty = kwargs.get("repetition_penalty", SamplingParams._repetition_penalty)
+ self.repetition_penalty = 1.0 if repetition_penalty is None else repetition_penalty
+
+ temperature = kwargs.get("temperature", SamplingParams._temperature)
+ self.temperature = 1.0 if temperature is None else temperature
+
+ top_p = kwargs.get("top_p", SamplingParams._top_p)
+ self.top_p = 1.0 if top_p is None else top_p
+
+ top_k = kwargs.get("top_k", SamplingParams._top_k)
+ self.top_k = -1 if top_k is None else top_k
self.ignore_eos = kwargs.get("ignore_eos", False)
self.image_max_patch_num = kwargs.get("image_max_patch_num", -1)
self.max_new_tokens = kwargs.get("max_new_tokens", 16384)
@@ -410,13 +426,35 @@ def init(self, tokenizer, **kwargs):
def load_generation_cfg(cls, weight_dir):
try:
generation_cfg = GenerationConfig.from_pretrained(weight_dir, trust_remote_code=True).to_dict()
+ # Some checkpoints store null sampling fields in generation_config.json.
+ # Keep robust numeric defaults instead of propagating None into ctypes fields.
cls._do_sample = generation_cfg.get("do_sample", False)
+ if cls._do_sample is None:
+ cls._do_sample = False
+
cls._presence_penalty = generation_cfg.get("presence_penalty", 0.0)
+ if cls._presence_penalty is None:
+ cls._presence_penalty = 0.0
+
cls._frequency_penalty = generation_cfg.get("frequency_penalty", 0.0)
+ if cls._frequency_penalty is None:
+ cls._frequency_penalty = 0.0
+
cls._repetition_penalty = generation_cfg.get("repetition_penalty", 1.0)
+ if cls._repetition_penalty is None:
+ cls._repetition_penalty = 1.0
+
cls._temperature = generation_cfg.get("temperature", 1.0)
+ if cls._temperature is None:
+ cls._temperature = 1.0
+
cls._top_p = generation_cfg.get("top_p", 1.0)
+ if cls._top_p is None:
+ cls._top_p = 1.0
+
cls._top_k = generation_cfg.get("top_k", -1)
+ if cls._top_k is None:
+ cls._top_k = -1
except:
pass
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index d3dc849664..a3f30a5aaa 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -31,7 +31,8 @@ class StartArgs:
batch_max_tokens: Optional[int] = field(default=None)
eos_id: List[int] = field(default_factory=list)
tool_call_parser: Optional[str] = field(
- default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]}
+ default=None,
+ metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen", "qwen3_coder"]},
)
reasoning_parser: Optional[str] = field(
default=None,
@@ -54,7 +55,7 @@ class StartArgs:
},
)
chat_template: Optional[str] = field(default=None)
- running_max_req_size: int = field(default=1000)
+ running_max_req_size: int = field(default=512)
tp: int = field(default=1)
dp: int = field(default=1)
nnodes: int = field(default=1)
@@ -108,7 +109,7 @@ class StartArgs:
disable_cudagraph: bool = field(default=False)
enable_prefill_cudagraph: bool = field(default=False)
prefll_cudagraph_max_handle_token: int = field(default=512)
- graph_max_batch_size: int = field(default=256)
+ graph_max_batch_size: int = field(default=512)
graph_split_batch_size: int = field(default=32)
graph_grow_step_size: int = field(default=16)
graph_max_len_in_batch: int = field(default=0)
@@ -135,7 +136,18 @@ class StartArgs:
ep_redundancy_expert_config_path: Optional[str] = field(default=None)
auto_update_redundancy_expert: bool = field(default=False)
mtp_mode: Optional[str] = field(
- default=None, metadata={"choices": ["vanilla_with_att", "eagle_with_att", "vanilla_no_att", "eagle_no_att"]}
+ default=None,
+ metadata={
+ "choices": [
+ "vanilla_with_att",
+ "eagle_with_att",
+ "vanilla_no_att",
+ "eagle_no_att",
+ "qwen3next_vanilla",
+ "qwen3next_eagle",
+ None,
+ ]
+ },
)
mtp_draft_model_dir: Optional[str] = field(default=None)
mtp_step: int = field(default=0)
@@ -160,3 +172,8 @@ class StartArgs:
metric_port: int = field(default=None)
multinode_httpmanager_port: int = field(default=12345)
multi_level_kv_cache_port: int = field(default=None)
+
+ # hybrid attention model (Qwen3Next)
+ mamba_cache_size: Optional[int] = field(default=None)
+ mamba_cache_ratio: Optional[float] = field(default=0.5)
+ mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]})
diff --git a/lightllm/server/function_call_parser.py b/lightllm/server/function_call_parser.py
index 3a8fddf744..13aab66179 100644
--- a/lightllm/server/function_call_parser.py
+++ b/lightllm/server/function_call_parser.py
@@ -11,6 +11,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import ast
import json
import orjson
import logging
@@ -1717,6 +1718,228 @@ def parse_streaming_increment(self, new_text: str, tools: List[Tool]) -> Streami
return StreamingParseResult(normal_text="", calls=calls)
+class Qwen3CoderDetector(BaseFormatDetector):
+ """
+ Detector for Qwen3-Coder XML-style function call format.
+
+ Format Structure:
+ ```
+
+
+
+ value1
+
+
+ value2
+
+
+
+ ```
+
+ Key differences from Qwen25Detector (JSON-based):
+ - Parameters are XML key-value pairs, not JSON objects
+ - Function name is embedded in the tag attribute
+ - Values need schema-aware type conversion (string by default)
+
+ Reference: https://docs.vllm.ai/projects/recipes/en/latest/Qwen/Qwen3-Coder-480B-A35B.html
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.bot_token = ""
+ self.eot_token = ""
+ self.tool_call_separator = "\n"
+
+ # Regex patterns
+ self.tool_call_block_regex = re.compile(r"(.*?)", re.DOTALL)
+ self.function_regex = re.compile(r"||(?=)|$)", re.DOTALL
+ )
+ self._normal_text_buffer = ""
+
+ def has_tool_call(self, text: str) -> bool:
+ return " Dict:
+ """Extract parameter type configuration from tool definitions."""
+ for tool in tools:
+ if tool.function.name == func_name and tool.function.parameters:
+ params = tool.function.parameters
+ if isinstance(params, dict) and "properties" in params:
+ return params["properties"]
+ elif isinstance(params, dict):
+ return params
+ return {}
+
+ def _convert_param_value(self, value: str, param_name: str, param_config: Dict, func_name: str) -> Any:
+ """Convert parameter value based on schema type. Safe alternative to eval()."""
+ if value.lower() == "null":
+ return None
+
+ if param_name not in param_config:
+ return value
+
+ prop = param_config.get(param_name, {})
+ param_type = str(prop.get("type", "string")).strip().lower() if isinstance(prop, dict) else "string"
+
+ if param_type in ("string", "str", "enum"):
+ return value
+ elif param_type.startswith("int") or param_type == "integer":
+ try:
+ return int(value)
+ except (ValueError, TypeError):
+ return value
+ elif param_type in ("number", "float", "double"):
+ try:
+ fv = float(value)
+ return int(fv) if fv == int(fv) else fv
+ except (ValueError, TypeError):
+ return value
+ elif param_type in ("boolean", "bool"):
+ return value.lower() == "true"
+ elif param_type in ("object", "array"):
+ try:
+ return json.loads(value)
+ except (json.JSONDecodeError, TypeError, ValueError):
+ try:
+ return ast.literal_eval(value)
+ except (ValueError, SyntaxError, TypeError):
+ return value
+ return value
+
+ def _parse_function_call(self, function_str: str, tools: List[Tool]) -> Optional[ToolCallItem]:
+ """Parse a single ... block into a ToolCallItem."""
+ try:
+ end_index = function_str.index(">")
+ except ValueError:
+ return None
+
+ func_name = function_str[:end_index].strip()
+ tool_indices = self._get_tool_indices(tools)
+ if func_name not in tool_indices:
+ logger.warning(f"Model attempted to call undefined function: {func_name}")
+ return None
+
+ parameters_text = function_str[end_index + 1 :]
+ param_config = self._get_param_config(func_name, tools)
+ param_dict = {}
+
+ for match in self.parameter_regex.findall(parameters_text):
+ try:
+ idx = match.index(">")
+ except ValueError:
+ continue
+ param_name = match[:idx].strip()
+ param_value = match[idx + 1 :]
+ # Strip leading/trailing newlines from value
+ if param_value.startswith("\n"):
+ param_value = param_value[1:]
+ if param_value.endswith("\n"):
+ param_value = param_value[:-1]
+
+ param_dict[param_name] = self._convert_param_value(param_value, param_name, param_config, func_name)
+
+ return ToolCallItem(
+ tool_index=tool_indices[func_name],
+ name=func_name,
+ parameters=json.dumps(param_dict, ensure_ascii=False),
+ )
+
+ def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
+ idx = text.find(self.bot_token)
+ normal_text = text[:idx].strip() if idx != -1 else text
+
+ if " StreamingParseResult:
+ """Streaming incremental parsing for Qwen3-Coder XML tool calls."""
+ self._buffer += new_text
+ current_text = self._buffer
+
+ if not self.has_tool_call(current_text):
+ partial_len = self._ends_with_partial_token(current_text, self.bot_token)
+ if partial_len:
+ return StreamingParseResult()
+ self._buffer = ""
+ cleaned = new_text.replace(self.eot_token, "")
+ return StreamingParseResult(normal_text=cleaned)
+
+ # Check for complete tool call blocks
+ if self.eot_token in current_text:
+ result = self.detect_and_parse(current_text, tools)
+ last_end = current_text.rfind(self.eot_token)
+ if last_end != -1:
+ self._buffer = current_text[last_end + len(self.eot_token) :].lstrip()
+ else:
+ self._buffer = ""
+ self.current_tool_id = -1
+ self.current_tool_name_sent = False
+ return result
+
+ # Partial tool call - try to extract function name for early streaming
+ if not hasattr(self, "_tool_indices"):
+ self._tool_indices = self._get_tool_indices(tools)
+
+ calls = []
+ tool_call_start = current_text.find(self.bot_token)
+ if tool_call_start == -1:
+ return StreamingParseResult()
+
+ content_after = current_text[tool_call_start + len(self.bot_token) :]
+ func_prefix = "")
+ if gt_pos == -1:
+ return StreamingParseResult()
+
+ func_name = after_func[:gt_pos].strip()
+
+ if self.current_tool_id == -1:
+ self.current_tool_id = 0
+ self.prev_tool_call_arr = []
+ self.streamed_args_for_tool = [""]
+
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
+ self.prev_tool_call_arr.append({})
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
+ self.streamed_args_for_tool.append("")
+
+ if func_name and func_name in self._tool_indices and not self.current_tool_name_sent:
+ calls.append(
+ ToolCallItem(
+ tool_index=self.current_tool_id,
+ name=func_name,
+ parameters="",
+ )
+ )
+ self.current_tool_name_sent = True
+ self.prev_tool_call_arr[self.current_tool_id] = {"name": func_name, "arguments": {}}
+
+ return StreamingParseResult(normal_text="", calls=calls)
+
+
class FunctionCallParser:
"""
Parser for function/tool calls in model outputs.
@@ -1736,6 +1959,7 @@ class FunctionCallParser:
"mistral": MistralDetector,
"qwen": Qwen25Detector,
"qwen25": Qwen25Detector,
+ "qwen3_coder": Qwen3CoderDetector,
}
def __init__(self, tools: List[Tool], tool_call_parser: str):
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index e28e4c93ad..2cbbb7064a 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -296,6 +296,14 @@ async def generate(
if self.pd_mode.is_P_or_NORMAL():
await multimodal_params.verify_and_preload(request)
+ # Debug logging for multimodal requests
+ if multimodal_params and multimodal_params.images:
+ logger.debug(
+ f"[MULTIMODAL_DEBUG] req_id={group_request_id}, "
+ f"num_images={len(multimodal_params.images)}, "
+ f"max_new_tokens={sampling_params.max_new_tokens}"
+ )
+
# 记录请求到达的相关信息
await self._log_req_header(request_headers, group_request_id)
# encode
diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
new file mode 100644
index 0000000000..08f6ba3fff
--- /dev/null
+++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py
@@ -0,0 +1,173 @@
+from typing import Set, Protocol, List, Optional, Tuple
+
+import torch
+from sortedcontainers import SortedSet
+
+from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
+from lightllm.common.mamba_cache_mem_manager.cache_manager import MambaCacheManager
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
+
+
+class HybridRadixCache(RadixCache):
+    def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager):
+        super().__init__(unique_name, total_token_num, rank_in_node, kv_cache_mem_manager)
+        assert hasattr(kv_cache_mem_manager, "mamba_cache_mem_manager")
+        self.buffer_mem_manager: MambaCacheManager = kv_cache_mem_manager.mamba_cache_mem_manager
+        self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.buffer_time,))
+
+    def free_radix_cache_to_get_enough_buffer(self, need_buffer_num):
+        if need_buffer_num > self.buffer_mem_manager.can_use_mem_size:
+            need_evict_buffer_num = need_buffer_num - self.buffer_mem_manager.can_use_mem_size
+
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self._evict_buffer(need_evict_buffer_num, release_buffer, release_mem)
+            self.buffer_mem_manager.free(release_buffers)
+            if len(release_mems) > 0:
+                mem_index = torch.concat(release_mems)
+                self.mem_manager.free(mem_index)
+        return
+
+    def _evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback):
+        while need_evict_buffer_num > 0:
+            node = self.evict_buffer_set.pop(0)
+            assert node.buffer_idx is not None
+            evict_buffer_callback(node.buffer_idx)
+            node.buffer_idx = None
+            need_evict_buffer_num -= 1
+            # Once a node's buffer_idx becomes None it can no longer be matched later.
+            # Keep the node while it still has children or live references; otherwise delete it.
+            if node.is_leaf() and node.ref_counter == 0:
+                self.evict_tree_set.discard(node)
+                evict_token_callback(node.token_mem_index_value)
+                self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+                parent_node: TreeNode = node.parent
+                parent_node.remove_child(node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+        return
+
+    def match_prefix(self, key, update_refs=False):
+        assert len(key) != 0
+        ans_value_list = []
+        tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs)
+        miss_prefix_len = 0
+        evict_token_list = []
+        kv_len = tree_node.node_prefix_total_len
+        while tree_node != self.root_node and tree_node.buffer_idx is None:
+            if tree_node.is_leaf():
+                self.evict_tree_set.discard(tree_node)
+
+            # Only update ref_counter when update_refs is True to maintain consistency
+            # with _match_prefix_helper which only increments ref_counter when update_refs=True
+            if update_refs:
+                if tree_node.ref_counter == 1:
+                    self.refed_tokens_num.arr[0] -= len(tree_node.token_mem_index_value)
+                tree_node.ref_counter -= 1  # only decrement this node; do not recurse into ancestors
+
+            if tree_node.is_leaf() and tree_node.ref_counter == 0:
+                evict_token_list.append(tree_node.token_mem_index_value)
+                self.tree_total_tokens_num.arr[0] -= len(tree_node.token_mem_index_value)
+                parent_node: TreeNode = tree_node.parent
+                parent_node.remove_child(tree_node)
+                if parent_node.is_leaf():
+                    self.evict_tree_set.add(parent_node)
+                tree_node = parent_node
+            else:
+                if tree_node.is_leaf():
+                    self.evict_tree_set.add(tree_node)
+                tree_node = tree_node.parent
+            miss_prefix_len += len(ans_value_list.pop())
+
+        if len(evict_token_list) > 0:
+            evict_token_value = torch.concat(evict_token_list)
+            self.mem_manager.free(evict_token_value)
+
+        if tree_node == self.root_node:
+            return None, kv_len - miss_prefix_len, None
+
+        update_node = tree_node
+        while update_node != self.root_node:
+            if update_node.buffer_idx is not None:
+                self.evict_buffer_set.discard(update_node)
+                update_node.update_buffer_time()
+                self.evict_buffer_set.add(update_node)
+            update_node = update_node.parent
+
+        value = torch.concat(ans_value_list)
+        return tree_node, miss_prefix_len, value  # NOTE(review): root branch above returns the MATCHED len (kv_len - miss); here the second slot is the MISS len — confirm callers expect this asymmetry
+
+    def add_buffer_idx_to_node(self, node: TreeNode, buffer_idx: int):
+        """Set buffer_idx for a node and (re)insert it into evict_buffer_set, freeing any old buffer."""
+        self.evict_buffer_set.discard(node)
+        if node.is_leaf():
+            self.evict_tree_set.discard(node)
+        if node.buffer_idx is not None:
+            self.buffer_mem_manager.free([node.buffer_idx])
+        node.buffer_idx = buffer_idx
+        node.update_buffer_time()
+        self.evict_buffer_set.add(node)
+        if node.is_leaf():
+            self.evict_tree_set.add(node)
+        return
+
+    def free_radix_cache_to_get_enough_token(self, need_token_num):
+        assert self.mem_manager is not None
+        if need_token_num > self.mem_manager.can_use_mem_size:
+            need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size
+            release_mems = []
+
+            def release_mem(mem_index):
+                release_mems.append(mem_index)
+                return
+
+            release_buffers = []
+
+            def release_buffer(buffer_idx):
+                release_buffers.append(buffer_idx)
+                return
+
+            self.evict(need_evict_token_num, release_buffer, release_mem)
+            mem_index = torch.concat(release_mems)
+            self.mem_manager.free(mem_index)
+            if len(release_buffers) > 0:
+                self.buffer_mem_manager.free(release_buffers)
+        return
+
+    def evict(self, need_remove_tokens, evict_buffer_callback, evict_callback):
+        if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens:
+            assert False, f"""can not free tree tokens {need_remove_tokens},
+                              tree_total_tokens_num {self.tree_total_tokens_num.arr[0]},
+                              refed_tokens_num {self.refed_tokens_num.arr[0]}"""
+        num_evicted = 0
+        while num_evicted < need_remove_tokens:
+            node: TreeNode = self.evict_tree_set.pop(0)
+            assert (
+                node.ref_counter == 0 and len(node.children) == 0 and node != self.root_node
+            ), f"error evict tree node state: {node.ref_counter}, {len(node.children)}"
+            num_evicted += len(node.token_mem_index_value)
+            evict_callback(node.token_mem_index_value)
+            if node.buffer_idx is not None:
+                self.evict_buffer_set.discard(node)
+                evict_buffer_callback(node.buffer_idx)
+                node.buffer_idx = None
+            # Update the bookkeeping of total cached tokens after detaching this node.
+            self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value)
+            parent_node: TreeNode = node.parent
+            parent_node.remove_child(node)
+            if parent_node.is_leaf():
+                self.evict_tree_set.add(parent_node)
+
+        return
diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py
index 88b099459b..9f8d78c491 100644
--- a/lightllm/server/router/dynamic_prompt/radix_cache.py
+++ b/lightllm/server/router/dynamic_prompt/radix_cache.py
@@ -31,6 +31,12 @@ def __init__(self):
self.node_value_len = 0
self.node_prefix_total_len = 0
+ # Used by hybrid attention models (e.g., Qwen3Next) to track
+ # a per-request buffer_idx alongside the token-level KV cache.
+ # Pure attention models keep buffer_idx as None.
+ self.buffer_idx = None
+ self.buffer_time = time_gen.generate_time_id()
+
def get_compare_key(self):
return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id)
@@ -78,6 +84,9 @@ def remove_child(self, child_node: "TreeNode"):
def update_time(self):
self.time_id = time_gen.generate_time_id()
+ def update_buffer_time(self):
+ self.buffer_time = time_gen.generate_time_id()
+
def is_leaf(self):
return len(self.children) == 0
@@ -103,10 +112,10 @@ class RadixCache:
unique_name 主要用于解决单机,多实列部署时的shm冲突
"""
- def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None):
+ def __init__(self, unique_name, total_token_num, rank_in_node, kv_cache_mem_manager=None):
from lightllm.common.kv_cache_mem_manager import MemoryManager
- self.mem_manager: MemoryManager = mem_manager
+ self.mem_manager: MemoryManager = kv_cache_mem_manager
self._key_dtype = torch.int64
self._value_dtype = torch.int64
@@ -359,6 +368,7 @@ def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]:
or parent_node.ref_counter != 0
or len(parent_node.children) != 1
or child_node.ref_counter != 0
+ or parent_node.buffer_idx is not None
):
return None
@@ -489,7 +499,7 @@ def _print_helper(self, node: TreeNode, indent):
" " * indent,
f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \
time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \
- node_value_len: {node.node_value_len}",
+ node_value_len: {node.node_value_len} buffer_idx: {node.buffer_idx}",
)
for _, child in node.children.items():
self._print_helper(child, indent=indent + 2)
diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 0a83b101be..731fbea405 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -7,10 +7,11 @@
from dataclasses import dataclass, field
from typing import List, Dict, Tuple, Optional, Callable, Any
-from lightllm.common.req_manager import ReqManager
+from lightllm.common.req_manager import ReqManager, ReqManagerForMamba
from lightllm.utils.infer_utils import mark_start, mark_end
from lightllm.server.core.objs import Req, SamplingParams, FinishStatus, ShmReqManager
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode
+from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache
from lightllm.utils.log_utils import init_logger
from lightllm.server.req_id_generator import convert_sub_id_to_group_id
from lightllm.common.basemodel.infer_lock import g_infer_state_lock
@@ -22,6 +23,9 @@
logger = init_logger(__name__)
+# Cache for mtp_range tensors to avoid repeated allocation
+_mtp_range_cache: Dict[int, torch.Tensor] = {}
+
@dataclass
class InferenceContext:
@@ -32,10 +36,15 @@ class InferenceContext:
infer_req_ids = None
vocab_size = None
cpu_embed_cache_client: Optional[CpuEmbedCacheClient] = None
+ mtp_step: int = 0
overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。
cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream
+ @property
+ def has_recurrent_state(self):
+ return self.req_manager is not None and self.req_manager.has_recurrent_state
+
def register(
self,
backend,
@@ -57,6 +66,9 @@ def register(
self.infer_req_ids = []
self.vocab_size = vocab_size
+
+ self.mtp_step = get_env_start_args().mtp_step
+
return
def init_cpu_embed_cache_client(self):
@@ -73,6 +85,31 @@ def get_cpu_kv_cache_stream(self) -> torch.cuda.Stream:
self.cpu_kv_cache_stream = torch.cuda.Stream()
return self.cpu_kv_cache_stream
+ def _alloc_and_copy_req_buffers(
+ self, req_manager: ReqManagerForMamba, radix_cache: HybridRadixCache, req_objs: List["InferReq"]
+ ) -> None:
+ if not req_objs:
+ return
+
+ if radix_cache is not None:
+ radix_cache.free_radix_cache_to_get_enough_buffer(len(req_objs) * (self.mtp_step + 1))
+
+ req_idx_gpu = torch.tensor([r.req_idx for r in req_objs], device="cuda", dtype=torch.int64)
+ req_manager.alloc_buffer_for_req(req_idx_gpu)
+
+ if radix_cache is not None:
+ fork_req_ids = [r.req_idx for r in req_objs if r.shared_kv_node is not None]
+ if fork_req_ids:
+ src_buf_ids = [r.shared_kv_node.buffer_idx for r in req_objs if r.shared_kv_node is not None]
+ req_tensor = torch.tensor(fork_req_ids, device="cuda", dtype=torch.int32)
+ src_tensor = torch.tensor(src_buf_ids, device="cuda", dtype=torch.int32)
+
+ mtp_step = req_manager.mtp_step
+ if mtp_step not in _mtp_range_cache:
+ _mtp_range_cache[mtp_step] = torch.arange(0, mtp_step + 1, dtype=torch.int32, device="cuda")
+ dst_buffers = req_manager.req_to_buffer_index[req_tensor[:, None], _mtp_range_cache[mtp_step][None, :]]
+ req_manager.buffer_mem_manager.fork_state_buffers(src_tensor, dst_buffers)
+
def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache: bool = True) -> List["InferReq"]:
req_objs = []
request_ids = []
@@ -111,6 +148,9 @@ def add_reqs(self, requests: List[Tuple[int, int, Any, int]], init_prefix_cache:
slave_req: InferReq = slave_req
slave_req.related_master_req = master_req
+ if isinstance(self.req_manager, ReqManagerForMamba):
+ self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, req_objs)
+
return req_objs
def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
@@ -122,7 +162,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
# .cpu() 是 流内阻塞操作
value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()
- prefix_len, _ = self.radix_cache.insert(key, value)
+ prefix_len, node = self.radix_cache.insert(key, value)
+
old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
if req.shared_kv_node is not None:
@@ -130,6 +171,42 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"):
self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
req.shared_kv_node = None
+    def free_a_req_mem_for_mamba(self, free_token_index: List, req: "InferReq") -> bool:
+        if self.radix_cache is None:
+            free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][0 : req.cur_kv_len])
+        else:
+            input_token_ids = req.get_input_token_ids()
+            key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu")
+            value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu()
+
+            prefix_len, node = self.radix_cache.insert(key, value)
+            old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len
+            free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len])
+            if req.shared_kv_node is not None:
+                assert req.shared_kv_node.node_prefix_total_len <= prefix_len
+                self.radix_cache.dec_node_ref_counter(req.shared_kv_node)
+                req.shared_kv_node = None
+
+            if node.buffer_idx is None:
+                req_to_buffer_index = self.req_manager.req_to_buffer_index
+                buffer_idx = req_to_buffer_index[req.req_idx, 0].item()
+                self.radix_cache.add_buffer_idx_to_node(node, buffer_idx)
+                # This request's buffer is now owned by the radix cache; no manual free is needed.
+                return False
+        return True
+
+    def _free_req_mem_and_buffers(self, free_token_index: List, free_buffer_index: List, req: "InferReq"):
+        """Release the request's KV-cache tokens and, for recurrent-state (mamba) models, its state buffers."""
+        if self.has_recurrent_state:
+            need_free_base_buffer = self.free_a_req_mem_for_mamba(free_token_index, req)
+            req_to_buffer_index = self.req_manager.req_to_buffer_index
+            if need_free_base_buffer:
+                free_buffer_index.extend(req_to_buffer_index[req.req_idx, :].tolist())
+            elif self.mtp_step > 0:
+                free_buffer_index.extend(req_to_buffer_index[req.req_idx, 1:].tolist())
+        else:
+            self.free_a_req_mem(free_token_index, req)
+
def _save_promptcache_kvbuffer(self):
"""
save prompt cache kv buffer
@@ -151,19 +228,23 @@ def _filter(self, finished_request_ids: List[int]):
free_req_index = []
free_token_index = []
+ free_buffer_index = []
for request_id in finished_request_ids:
req: InferReq = self.requests_mapping.pop(request_id)
if self.args.diverse_mode:
req.clear_master_slave_state()
- self.free_a_req_mem(free_token_index, req)
-
+ self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req)
free_req_index.append(req.req_idx)
# logger.info(f"infer release req id {req.shm_req.request_id}")
req.shm_req.shm_infer_released = True
self.shm_req_manager.put_back_req_obj(req.shm_req)
- free_token_index = custom_cat(free_token_index)
- self.req_manager.free(free_req_index, free_token_index)
+ if len(free_token_index) != 0:
+ free_token_index = custom_cat(free_token_index)
+ self.req_manager.free(free_req_index, free_token_index)
+
+ if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba):
+ self.req_manager.free_buffer(free_buffer_index)
finished_req_ids_set = set(finished_request_ids)
self.infer_req_ids = [_id for _id in self.infer_req_ids if _id not in finished_req_ids_set]
@@ -191,12 +272,15 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
if pause_reqs:
g_infer_state_lock.acquire()
+ pause_req_indices = []
free_token_index = []
+ free_buffer_index = []
for req in pause_reqs:
+ pause_req_indices.append(req.req_idx)
if self.args.diverse_mode:
# 发生暂停的时候,需要清除 diverse 模式下的主从关系
req.clear_master_slave_state()
- self.free_a_req_mem(free_token_index, req)
+ self._free_req_mem_and_buffers(free_token_index, free_buffer_index, req)
req.cur_kv_len = 0
req.shm_req.shm_cur_kv_len = req.cur_kv_len
assert req.wait_pause is True
@@ -210,13 +294,16 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
free_token_index = custom_cat(free_token_index)
self.req_manager.free_token(free_token_index)
+ if len(free_buffer_index) != 0 and isinstance(self.req_manager, ReqManagerForMamba):
+ self.req_manager.free_buffer(free_buffer_index)
+
g_infer_state_lock.release()
return self
def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bool, can_alloc_token_num: int):
if paused_reqs:
g_infer_state_lock.acquire()
-
+ revovered_reqs = []
for req in paused_reqs:
prefill_need_token_num = req.get_cur_total_len()
if prefill_need_token_num > can_alloc_token_num:
@@ -228,7 +315,12 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
                 req.shm_req.is_paused = False
                 logger.debug(f"infer recover paused req id {req.req_id}")
                 can_alloc_token_num -= prefill_need_token_num
+                revovered_reqs.append(req)
 
+            # State buffers only exist for ReqManagerForMamba, and _alloc_and_copy_req_buffers
+            # requires (req_manager, radix_cache, req_objs) — passing only the request list raised TypeError.
+            if isinstance(self.req_manager, ReqManagerForMamba):
+                self._alloc_and_copy_req_buffers(self.req_manager, self.radix_cache, revovered_reqs)
             g_infer_state_lock.release()
         return
 
@@ -466,8 +555,8 @@ def get_input_token_ids(self):
return self.shm_req.shm_prompt_ids.arr[0 : self.get_cur_total_len()]
     def get_chuncked_input_token_ids(self):
-        chunked_start = self.cur_kv_len
-        chunked_end = min(self.get_cur_total_len(), chunked_start + self.shm_req.chunked_prefill_size)
+        # Reuse get_chuncked_input_token_len's logic so chunk boundaries stay consistent.
+        chunked_end = self.get_chuncked_input_token_len()
         return self.shm_req.shm_prompt_ids.arr[0:chunked_end]
def get_chuncked_input_token_len(self):
diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py
index 8b085c45ed..08932e4e41 100644
--- a/lightllm/server/router/model_infer/mode_backend/base_backend.py
+++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -9,7 +9,6 @@
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.log_utils import init_logger
from lightllm.models import get_model
-from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack
from lightllm.server.router.token_load import TokenLoad
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
@@ -172,12 +171,14 @@ def init_model(self, kvargs):
self.model, self.is_multimodal = get_model(model_cfg, model_kvargs)
self.model: TpPartBaseModel = self.model # for easy typing
set_random_seed(2147483647)
+
+ radix_cache_class = self.model.get_radix_class()
self.radix_cache = (
- RadixCache(
+ radix_cache_class(
get_unique_server_name(),
self.model.mem_manager.size,
self.rank_in_node,
- mem_manager=self.model.mem_manager,
+ kv_cache_mem_manager=self.model.mem_manager,
)
if self.use_dynamic_prompt_cache
else None
@@ -287,9 +288,8 @@ def decode(self, event_pack: OverlapEventPack, decode_reqs: List[InferReq]):
raise NotImplementedError()
def init_mtp_draft_model(self, main_kvargs: dict):
- # 当前只支持 deepseekv3 模式的 mtp
self.mtp_step = self.args.mtp_step
- self.draft_models: List[Deepseek3MTPModel] = []
+ self.draft_models = []
os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1"
@@ -302,6 +302,7 @@ def init_mtp_draft_model(self, main_kvargs: dict):
for i in range(num_mtp_modules):
mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i])
+ model_type = mtp_model_cfg.get("model_type", "")
mtp_model_kvargs = {
"weight_dir": self.args.mtp_draft_model_dir[i],
"max_total_token_num": self.model.mem_manager.size,
@@ -324,21 +325,22 @@ def init_mtp_draft_model(self, main_kvargs: dict):
"mtp_previous_draft_models": self.draft_models.copy(),
}
- mtp_model_cfg, _ = PretrainedConfig.get_config_dict(self.args.mtp_draft_model_dir[i])
- if mtp_model_cfg["model_type"] == "deepseek_v3":
+ # Select MTP model class based on model type
+ model_type = mtp_model_cfg.get("model_type", "")
+ if model_type == "deepseek_v3":
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
self.draft_models.append(Deepseek3MTPModel(mtp_model_kvargs))
- elif mtp_model_cfg["model_type"] == "qwen3_moe":
+ elif model_type == "qwen3_moe":
assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
self.draft_models.append(Qwen3MOEMTPModel(mtp_model_kvargs))
- elif mtp_model_cfg["model_type"] == "mistral":
+ elif model_type == "mistral":
assert self.args.mtp_mode in ["vanilla_no_att", "eagle_no_att"]
self.draft_models.append(MistralMTPModel(mtp_model_kvargs))
elif mtp_model_cfg["model_type"] == "glm4_moe_lite":
assert self.args.mtp_mode in ["vanilla_with_att", "eagle_with_att"]
self.draft_models.append(Glm4MoeLiteMTPModel(mtp_model_kvargs))
else:
- assert False, f"error mtp mode {mtp_model_cfg['model_type']}"
+ raise ValueError(f"Unsupported MTP model type: {model_type}")
self.logger.info(f"loaded mtp model class {self.draft_models[i].__class__}")
return
diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py
index 2800bf0f6b..25726b2578 100644
--- a/lightllm/server/tokenizer.py
+++ b/lightllm/server/tokenizer.py
@@ -112,6 +112,14 @@ def get_tokenizer(
tokenizer = QWen3VLTokenizer(
tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg
)
+ elif model_type in ["qwen3_5", "qwen3_5_moe"] and "vision_config" in model_cfg:
+ from transformers import AutoProcessor
+ from ..models.qwen3_5.model import QWen3_5Tokenizer
+
+ processor = AutoProcessor.from_pretrained(tokenizer_name)
+ tokenizer = QWen3_5Tokenizer(
+ tokenizer=tokenizer, image_processor=processor.image_processor, model_cfg=model_cfg
+ )
elif model_cfg.get("thinker_config") is not None:
from transformers import AutoProcessor
diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py
index 3e97f4de3e..ed4665e725 100644
--- a/lightllm/server/visualserver/model_infer/model_rpc.py
+++ b/lightllm/server/visualserver/model_infer/model_rpc.py
@@ -68,7 +68,7 @@ def exposed_init_model(self, kvargs):
self.model = (
Qwen2_5_VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
)
- elif self.model_type in ["qwen3_vl", "qwen3_vl_moe"]:
+ elif self.model_type in ["qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
self.model = (
Qwen3VisionTransformerPretrainedModel(kvargs, **model_cfg["vision_config"]).eval().bfloat16()
)
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
index 54f20384b7..a4fbc594bc 100644
--- a/lightllm/utils/config_utils.py
+++ b/lightllm/utils/config_utils.py
@@ -87,6 +87,22 @@ def get_eos_token_ids(model_path: str) -> Optional[List[int]]:
except:
pass
+ # Qwen3.5 checkpoints can have an eos_token_id in config that differs from
+ # tokenizer.eos_token_id. In practice tokenizer.eos_token_id is the reliable
+ # stop id (<|im_end|>) for detokenization/stop behavior.
+ try:
+ config_json = get_config_json(model_path)
+ model_type = config_json.get("model_type") or config_json.get("text_config", {}).get("model_type")
+ if model_type in {"qwen3_5", "qwen3_5_text", "qwen3_5_moe", "qwen3_5_moe_text"}:
+ from transformers import AutoTokenizer
+
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)
+ if tokenizer.eos_token_id is not None:
+ return [int(tokenizer.eos_token_id)]
+ except Exception:
+ # Fall back to config-based lookup below.
+ pass
+
eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"])
if isinstance(eos_token_id, int):
return [eos_token_id]
@@ -186,6 +202,8 @@ def has_vision_module(model_path: str) -> bool:
):
# Qwen3OmniMoeVisionTransformerPretrainedModel
return True
+ elif model_type in ["qwen3_5", "qwen3_5_moe"]:
+ return True
else:
raise Exception("unknown vision model type")
except:
diff --git a/test_gsmk.py b/test_gsmk.py
new file mode 100644
index 0000000000..78a5aa467f
--- /dev/null
+++ b/test_gsmk.py
@@ -0,0 +1,241 @@
+# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py
+import argparse
+import ast
+import json
+import os
+import re
+import time
+from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
+
+import numpy as np
+import requests
+from tqdm import tqdm
+
+INVALID = -9999999
+
+
+def read_jsonl(filename: str):
+ """Read a JSONL file."""
+ with open(filename) as fin:
+ for line in fin:
+ if line.startswith("#"):
+ continue
+ yield json.loads(line)
+
+
+def dump_state_text(filename: str, states: list, mode: str = "w"):
+ """Dump program state in a text file."""
+ with open(filename, mode) as fout:
+ for i, s in enumerate(states):
+ if isinstance(s, str):
+ fout.write(f"==== {i} ====\n{s}\n")
+ else:
+ fout.write(f"==== {i} ====\n{str(s)}\n")
+
+
+def download_and_cache_file(url: str, filename: Optional[str] = None):
+ """Read and cache a file from a url."""
+ if filename is None:
+ filename = os.path.join("/tmp", url.split("/")[-1])
+
+ # Check if the cache file already exists
+ if os.path.exists(filename):
+ return filename
+
+    print(f"Downloading from {url} to {filename}")
+
+ # Stream the response to show the progress bar
+ response = requests.get(url, stream=True)
+ response.raise_for_status() # Check for request errors
+
+ # Total size of the file in bytes
+ total_size = int(response.headers.get("content-length", 0))
+ chunk_size = 1024 # Download in chunks of 1KB
+
+ # Use tqdm to display the progress bar
+ with open(filename, "wb") as file, tqdm(
+ desc="Downloading",
+ total=total_size,
+ unit="iB",
+ unit_scale=True,
+ unit_divisor=1024,
+ ) as bar:
+ for chunk in response.iter_content(chunk_size=chunk_size):
+ size = file.write(chunk)
+ bar.update(size)
+
+ return filename
+
+
+def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None):
+ """Call LightLLM API for text generation."""
+ assert url is not None
+
+ data = {
+ "inputs": prompt,
+ "parameters": {
+ "temperature": temperature,
+ "max_new_tokens": max_tokens,
+ "stop_sequences": stop,
+ "repetition_penalty": 1.0,
+ "top_p": 1.0,
+ "top_k": 1,
+ },
+ }
+ res = requests.post(url, json=data)
+ assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}"
+
+ response_json = res.json()
+ if "generated_text" not in response_json:
+ raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}")
+ if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0:
+ raise ValueError(
+ "Invalid API response format. 'generated_text' should be a non-empty list, "
+ f"got: {response_json['generated_text']}"
+ )
+
+ pred = response_json["generated_text"][0]
+ return pred
+
+
+def get_one_example(lines, i, include_answer):
+ ret = "Question: " + lines[i]["question"] + "\nAnswer:"
+ if include_answer:
+ ret += " " + lines[i]["answer"]
+ return ret
+
+
+def get_few_shot_examples(lines, k):
+ ret = ""
+ for i in range(k):
+ ret += get_one_example(lines, i, True) + "\n\n"
+ return ret
+
+
+def get_answer_value(answer_str):
+ answer_str = answer_str.replace(",", "")
+ # First try to find the answer after "####" marker (GSM8K format)
+ match = re.search(r"####\s*(-?\d+)", answer_str)
+ if match:
+ try:
+ return ast.literal_eval(match.group(1))
+ except SyntaxError:
+ pass
+ # Fallback: find all numbers and take the last one
+ numbers = re.findall(r"\d+", answer_str)
+ if len(numbers) < 1:
+ return INVALID
+ try:
+ return ast.literal_eval(numbers[-1])
+ except SyntaxError:
+ return INVALID
+
+
+def parse_args():
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--parallel", type=int, default=256)
+ parser.add_argument("--host", type=str, default="http://127.0.0.1")
+ parser.add_argument("--port", type=int, default=8000)
+ parser.add_argument("--num-shots", type=int, default=5)
+ parser.add_argument("--num-questions", type=int, default=200)
+ parser.add_argument("--result-file", type=str, default="result.jsonl")
+ parser.add_argument("--data-path", type=str, default="test.jsonl")
+ return parser.parse_args()
+
+
+def main(args):
+ # LightLLM API URL
+ url = f"{args.host}:{args.port}/generate"
+
+ # Read data
+ url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl"
+ filename = download_and_cache_file(url_data)
+ lines = list(read_jsonl(filename))
+
+ # Construct prompts
+ num_questions = args.num_questions
+ num_shots = args.num_shots
+ few_shot_examples = get_few_shot_examples(lines, num_shots)
+
+ # Ensure we have enough samples and avoid data leakage
+ # Test questions should start after few-shot examples
+ max_available = len(lines) - num_shots
+ if num_questions > max_available:
+ print(
+ "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. "
+ "Using {} questions.".format(num_questions, max_available, num_shots, max_available)
+ )
+ num_questions = max_available
+
+ questions = []
+ labels = []
+ for i in range(num_shots, num_shots + num_questions):
+ questions.append(get_one_example(lines, i, False))
+ labels.append(get_answer_value(lines[i]["answer"]))
+ assert all(label != INVALID for label in labels)
+
+ states = [None] * len(labels)
+
+ # Run requests using thread pool
+ def get_one_answer(i):
+ answer = call_generate_lightllm(
+ prompt=few_shot_examples + questions[i],
+ temperature=0,
+ max_tokens=1024,
+ stop=["Question", "Assistant:", "<|separator|>", "Human:", "\n\nQuestion"],
+ url=url,
+ )
+ states[i] = answer
+
+ tic = time.perf_counter()
+ if args.parallel == 1:
+ for i in tqdm(range(len(questions))):
+ get_one_answer(i)
+ else:
+ with ThreadPoolExecutor(args.parallel) as executor:
+ list(
+ tqdm(
+ executor.map(get_one_answer, list(range(len(questions)))),
+ total=len(questions),
+ )
+ )
+
+ latency = time.perf_counter() - tic
+
+ preds = []
+ for i in range(len(states)):
+ preds.append(get_answer_value(states[i]))
+
+ # Compute accuracy
+ acc = np.mean(np.array(preds) == np.array(labels))
+ invalid = np.mean(np.array(preds) == INVALID)
+
+ # Print results
+ print(f"Accuracy: {acc:.3f}")
+ print(f"Invalid: {invalid:.3f}")
+ print(f"Latency: {latency:.3f} s")
+
+ # Dump results
+ dump_state_text("tmp_output_lightllm.txt", states)
+
+ with open(args.result_file, "a") as fout:
+ value = {
+ "task": "gsm8k",
+ "backend": "lightllm",
+ "num_gpus": 1,
+ "latency": round(latency, 3),
+ "accuracy": round(acc, 3),
+            "num_requests": num_questions,
+            "other": {
+                "num_questions": num_questions,
+ "parallel": args.parallel,
+ },
+ }
+ fout.write(json.dumps(value) + "\n")
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ main(args)