Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
7734c21
feat: add Qwen3Next linear attention model support
sufubao Feb 19, 2026
c757b06
refactor: simplify mamba buffer copy and integrate Triton kernels
sufubao Feb 20, 2026
340d11c
fix conv3d
sufubao Feb 21, 2026
a6a2435
[draft] qwen3.5 dense
sufubao Feb 26, 2026
054035d
split dense and moe
sufubao Feb 26, 2026
01b112a
feat: add mamba_cache_ratio for automatic memory allocation
sufubao Feb 26, 2026
174757d
refactor: simplify mamba_cache_ratio to direct percentage
sufubao Feb 26, 2026
dd2516e
add H100 config
sufubao Feb 26, 2026
326ae22
refactor: align radix_cache_class with infer_state_class style
sufubao Feb 27, 2026
e996cd2
fix: add missing attention_chunk param to flashattention_nopad.py
sufubao Feb 27, 2026
5e5cdbe
refactor: clarify naming in mamba_buffer_copy
sufubao Feb 27, 2026
9cf783c
clean
sufubao Feb 27, 2026
e120edb
fix
sufubao Feb 27, 2026
f3330cf
clean
sufubao Feb 27, 2026
d030a67
split
sufubao Feb 27, 2026
e1f6129
style: apply black formatting to mamba_buffer_copy
sufubao Mar 1, 2026
74f82d1
perf: add autotune configs for mamba_buffer_copy/fork kernels on H200
sufubao Mar 1, 2026
c1ea769
refactor: rename buffer copy methods for clarity
sufubao Mar 2, 2026
b81baaa
clean the code
sufubao Mar 2, 2026
0fd0202
clean code
sufubao Mar 2, 2026
eed0a9c
qwen35 qkv improve
shihaobai Mar 6, 2026
b9a386e
code simplify
shihaobai Mar 9, 2026
86f17b6
clean code
shihaobai Mar 9, 2026
a1849e6
fix
shihaobai Mar 13, 2026
61f74ac
remove contiguous
shihaobai Mar 16, 2026
bf0f254
remove gemma rms norm config
shihaobai Mar 16, 2026
76782c2
clean code
sufubao Mar 17, 2026
fdd2052
add get_radix_class
sufubao Mar 17, 2026
733e851
fix acc of mamba cache
shihaobai Mar 17, 2026
b1f8233
fix acc of mamba cache
shihaobai Mar 17, 2026
90120b0
fix warmup
shihaobai Mar 17, 2026
4ef6091
merge main
shihaobai Mar 18, 2026
13edba2
simplify the qwen3next layer_infer
shihaobai Mar 18, 2026
ec499ce
openai api simplify
shihaobai Mar 18, 2026
3c8597d
simplify mem manager
shihaobai Mar 18, 2026
20edcc1
slime code
shihaobai Mar 19, 2026
eed9863
remove mtp of base_backend
shihaobai Mar 19, 2026
90df4f1
slime mode_backend
shihaobai Mar 19, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions lightllm/common/basemodel/basemodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.basemodel.infer_struct import InferStateInfo
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.common.kv_cache_mem_manager import MemoryManager
from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class
from lightllm.common.req_manager import ReqManager
Expand Down Expand Up @@ -53,6 +54,9 @@ class TpPartBaseModel:
# infer state class
infer_state_class = InferStateInfo

def get_radix_class(self):
    """Return the radix-cache class this model should use.

    Overridable hook in the same style as the `infer_state_class`
    attribute: subclasses may return a different cache implementation
    (e.g. for models whose KV layout differs from standard attention —
    presumably the mamba/linear-attention models in this PR; confirm
    against the overriding subclasses).
    """
    return RadixCache

def __init__(self, kvargs):
self.args = get_env_start_args()
self.run_mode = kvargs["run_mode"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,21 @@ def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor
def _tpsp_ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
raise Exception("need to impl")

def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
def context_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    """Prefill (context) attention sub-layer.

    Runs QKV projection, KV-cache post-processing, the context attention
    kernel, and the output projection; the projected output is summed
    across tensor-parallel ranks before being returned.

    Note: `input_embdings` is expected to be the already-normalized hidden
    state — callers apply `_att_norm` first (see `context_forward`).
    """
    q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
    self._post_cache_kv(cache_kv, infer_state, layer_weight)

    out = self._context_attention_wrapper_run(
        q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
    )

    q = None  # drop the query tensor early to reduce peak activation memory
    out = self._get_o(out, infer_state, layer_weight)
    if self.tp_world_size_ > 1:
        all_reduce(out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
    return out

def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.context_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

Expand All @@ -87,39 +88,42 @@ def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_wei
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._get_qkv(input1, infer_state, layer_weight)
input1 = None
def token_attention_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    """Decode-step attention sub-layer on pre-normalized hidden states.

    Projects QKV, post-processes the new cache entry, runs the
    single-token attention kernel, projects the output, and all-reduces
    it across tensor-parallel ranks.
    """
    q, cache_kv = self._get_qkv(input_embdings, infer_state, layer_weight)
    self._post_cache_kv(cache_kv, infer_state, layer_weight)
    attn = self._token_attention_kernel(q, infer_state, layer_weight)
    q = None  # release q before the output projection
    attn = self._get_o(attn, infer_state, layer_weight)
    if self.tp_world_size_ > 1:
        all_reduce(attn, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
    return attn

def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight):
    """Decode-step transformer layer: residual attention, then residual FFN.

    Mutates `input_embdings` in place (residual adds) and returns it.
    The attention sub-layer performs its own all-reduce inside
    `token_attention_forward`; the FFN all-reduce happens here.
    """
    # Attention sub-layer with pre-norm and residual add.
    normed = self._att_norm(input_embdings, infer_state, layer_weight)
    attn_out = self.token_attention_forward(normed, infer_state, layer_weight)
    input_embdings.add_(attn_out.view(-1, self.embed_dim_))
    attn_out = None  # free before the FFN to keep peak memory down

    # FFN sub-layer with pre-norm, TP all-reduce, and residual add.
    normed = self._ffn_norm(input_embdings, infer_state, layer_weight)
    ffn_out = self._ffn(normed, infer_state, layer_weight)
    normed = None
    if self.tp_world_size_ > 1:
        all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False)
    input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
    return input_embdings

def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
input1 = None
def tpsp_context_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
    """Prefill attention sub-layer for the TP + sequence-parallel path.

    Same pipeline as `context_attention_forward`, but using the tpsp QKV
    and output projections. No explicit all-reduce is issued here —
    presumably `_tpsp_get_o` handles any required communication; confirm
    against its implementation.
    """
    q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
    self._post_cache_kv(cache_kv, infer_state, layer_weight)

    out = self._context_attention_wrapper_run(
        q=q, cache_kv=cache_kv, infer_state=infer_state, layer_weight=layer_weight
    )

    q = None  # drop q early to reduce peak activation memory
    out = self._tpsp_get_o(out, infer_state, layer_weight)
    return out

def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.tpsp_context_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

Expand All @@ -129,14 +133,17 @@ def tpsp_context_forward(self, input_embdings: torch.Tensor, infer_state: InferS
input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
return input_embdings

def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
q, cache_kv = self._tpsp_get_qkv(input1, infer_state, layer_weight)
input1 = None
def tpsp_token_attention_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
    """Decode-step attention sub-layer for the TP + sequence-parallel path.

    Mirrors `token_attention_forward` but with the tpsp projections and no
    explicit all-reduce (communication presumably lives in `_tpsp_get_o`
    — confirm).
    """
    q, cache_kv = self._tpsp_get_qkv(input_embdings, infer_state, layer_weight)
    self._post_cache_kv(cache_kv, infer_state, layer_weight)
    attn = self._token_attention_kernel(q, infer_state, layer_weight)
    q = None  # release q before the output projection
    attn = self._tpsp_get_o(attn, infer_state, layer_weight)
    return attn

def tpsp_token_forward(self, input_embdings: torch.Tensor, infer_state: InferStateInfo, layer_weight):
input1 = self._att_norm(input_embdings, infer_state, layer_weight)
o = self.tpsp_token_attention_forward(input1, infer_state, layer_weight)
input_embdings.add_(o.view(-1, self.embed_dim_))
o = None

Expand Down
11 changes: 10 additions & 1 deletion lightllm/common/basemodel/layer_weights/meta_weights/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,16 @@
QKVROWNMMWeight,
COLMMWeight,
)
from .norm_weight import TpRMSNormWeight, RMSNormWeight, LayerNormWeight, NoTpGEMMANormWeight, QKRMSNORMWeight
from .norm_weight import (
TpRMSNormWeight,
RMSNormWeight,
GEMMANormWeight,
LayerNormWeight,
NoTpGEMMANormWeight,
QKRMSNORMWeight,
QKGEMMANormWeight,
)
from .embedding_weight import EmbeddingWeight, LMHeadWeight, NoTpPosEmbeddingWeight
from .att_sink_weight import TpAttSinkWeight
from .fused_moe.fused_moe_weight import FusedMoeWeight
from .parameter_weight import ParameterWeight, TpParameterWeight
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def __call__(
return self._forward(input=input, eps=eps, out=out, alloc_func=alloc_func)


class GEMMANormWeight(RMSNormWeight):
    """RMSNorm weight loaded Gemma-style.

    Gemma checkpoints store the norm weight offset by -1 (the model scales
    by ``1 + w``), so we add 1 at load time; the parent class's forward can
    then presumably apply the weight directly — confirm against
    RMSNormWeight's kernel.
    """

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        if self.weight_name not in weights:
            return
        self.weight.copy_(weights[self.weight_name])
        self.weight.add_(1)  # fold Gemma's (1 + w) scaling into the stored weight
        self.weight.load_ok = True


class LayerNormWeight(BaseWeightTpl, PlatformAwareOp):
def __init__(self, dim: int, weight_name: str, data_type: torch.dtype, bias_name: str = None):
super().__init__(tp_rank=0, tp_world_size=1)
Expand Down Expand Up @@ -276,3 +284,23 @@ def __call__(
eps: float,
) -> None:
return self._forward(q=q, k=k, eps=eps)


class QKGEMMANormWeight(QKRMSNORMWeight):
    """Fused Q/K RMSNorm weights with Gemma-style (1 + w) parameterization.

    Both weights are shifted by +1 at load time, and the fused kernel is
    asked to multiply in fp32, matching Gemma's numerics.
    """

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]):
        # Load q and k identically: copy, fold in Gemma's +1 offset, mark loaded.
        for name, param in ((self.q_weight_name, self.q_weight), (self.k_weight_name, self.k_weight)):
            if name in weights:
                param.copy_(weights[name])
                param.add_(1)
                param.load_ok = True

    def _triton_forward(self, q: torch.Tensor, k: torch.Tensor, eps: float) -> tuple:
        assert q.ndim == 2 and self.q_weight.ndim == 1
        assert k.ndim == 2 and self.k_weight.ndim == 1
        # Llama does x.to(float16) * w whereas Gemma computes (x * w) in fp32
        # and casts afterwards (see https://github.com/huggingface/transformers/pull/29402),
        # hence fp32_multiply=True.
        return qk_rmsnorm_fused_forward(q=q, k=k, w_q=self.q_weight, w_k=self.k_weight, eps=eps, fp32_multiply=True)
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import torch
from typing import Dict, Optional, Tuple
from .base_weight import BaseWeightTpl


class ParameterWeight(BaseWeightTpl):
    """Plain (non-TP-split) parameter loaded from a HF checkpoint.

    Holds an optional ``weight`` and ``bias`` tensor. Each allocated tensor
    carries a ``load_ok`` flag so ``verify_load`` can confirm the checkpoint
    actually supplied it.
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        weight_shape: Optional[Tuple[int, ...]],
        bias_name: Optional[str] = None,
        bias_shape: Optional[Tuple[int, ...]] = None,
    ):
        super().__init__()
        # Checkpoint keys.
        self.weight_name = weight_name
        self.bias_name = bias_name
        # Target dtype and (already TP-adjusted, if any) buffer shapes.
        self.data_type_ = data_type
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.weight: Optional[torch.Tensor] = None
        self.bias: Optional[torch.Tensor] = None
        # Allocate eagerly only when the shape is known up front.
        if weight_shape is not None:
            self._create_weight()

    def _create_weight(self):
        # Allocate uninitialized device buffers and mark them un-loaded until
        # load_hf_weights fills them in.
        if self.weight_shape is not None:
            self.weight = torch.empty(*self.weight_shape, dtype=self.data_type_, device=self.device_id_)
            self.weight.load_ok = False
        if self.bias_name is not None and self.bias_shape is not None:
            self.bias = torch.empty(*self.bias_shape, dtype=self.data_type_, device=self.device_id_)
            self.bias.load_ok = False

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy matching tensors from a checkpoint shard into our buffers."""
        if self.weight_name in weights:
            loaded = weights[self.weight_name].to(self.data_type_)
            self.weight.copy_(loaded)
            self.weight.load_ok = True
        if self.bias_name is not None and self.bias_name in weights:
            loaded = weights[self.bias_name].to(self.data_type_)
            self.bias.copy_(loaded)
            self.bias.load_ok = True

    def verify_load(self) -> bool:
        """Return True iff every allocated tensor was filled by load_hf_weights."""
        for tensor in (self.weight, self.bias):
            if tensor is not None and not getattr(tensor, "load_ok", False):
                return False
        return True


class TpParameterWeight(ParameterWeight):
    """Tensor-parallel ParameterWeight.

    Each rank owns a contiguous ``split_n_embed``-sized slice of dimension 0;
    ``load_hf_weights`` slices the full checkpoint tensor by ``tp_rank_``.
    """

    def __init__(
        self,
        weight_name: str,
        data_type: torch.dtype,
        split_n_embed: int,
        bias_name: Optional[str] = None,
        weight_shape: Optional[Tuple[int, ...]] = None,
        bias_shape: Optional[Tuple[int, ...]] = None,
    ):
        self.split_n_embed = split_n_embed
        # Replace the full leading dimension with this rank's slice size when
        # the full shapes are known.
        tp_weight_shape = None if weight_shape is None else (split_n_embed,) + weight_shape[1:]
        tp_bias_shape = None if bias_shape is None else (split_n_embed,) + bias_shape[1:]
        super().__init__(weight_name, data_type, tp_weight_shape, bias_name, tp_bias_shape)

    def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None:
        """Copy this rank's dim-0 slice of each matching checkpoint tensor."""
        lo = self.split_n_embed * self.tp_rank_
        hi = lo + self.split_n_embed

        if self.weight_name in weights:
            shard = weights[self.weight_name][lo:hi]
            self.weight.copy_(shard.to(self.data_type_))
            self.weight.load_ok = True
        if self.bias_name is not None and self.bias_name in weights:
            shard = weights[self.bias_name][lo:hi]
            self.bias.copy_(shard.to(self.data_type_))
            self.bias.load_ok = True
Loading
Loading