Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from lightllm.utils.log_utils import init_logger
from lightllm.utils.vllm_utils import vllm_ops
from lightllm.utils.device_utils import triton_support_tensor_descriptor
from .moe_kernel_configs import MoeGroupedGemmKernelConfig
from .moe_silu_and_mul import silu_and_mul_fwd
from .moe_sum_reduce import moe_sum_reduce
from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8
Expand Down Expand Up @@ -726,16 +725,26 @@ def grouped_matmul(
block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2]

if run_config is None:
run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config(
M=token_inputs.shape[0],
N=n,
K=k,
topk_num=topk_num,
expert_num=expert_num,
mul_routed_weight=mul_routed_weight,
use_fp8_w8a8=use_fp8_w8a8,
out_dtype=str(out.dtype),
)
if token_inputs.shape[0] <= expert_num:
run_config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"NEED_TRANS": False,
"num_warps": 4,
"num_stages": 1,
}
else:
run_config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"NEED_TRANS": False,
"num_warps": 4,
"num_stages": 1,
}

BLOCK_SIZE_M = run_config["BLOCK_SIZE_M"]
BLOCK_SIZE_N = run_config["BLOCK_SIZE_N"]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import triton
import triton.language as tl
from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
from lightllm.common.triton_utils.autotuner import autotune


Expand Down Expand Up @@ -121,7 +120,10 @@ def silu_and_mul_fwd(
size_n = input.shape[-1] // 2

if not run_config:
run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype))
if size_m < 256:
run_config = {"BLOCK_M": 1, "BLOCK_N": 128, "num_warps": 1, "NUM_STAGES": 1}
else:
run_config = {"BLOCK_M": 16, "BLOCK_N": 128, "num_warps": 4, "NUM_STAGES": 5}

BLOCK_M = run_config["BLOCK_M"]
BLOCK_N = run_config["BLOCK_N"]
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import triton
import triton.language as tl
from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig


@triton.jit
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import torch

import triton
import triton.language as tl
from .moe_sum_recude_config import MoeSumReduceKernelConfig
from typing import Any, Callable, Dict, Optional, Tuple
from typing import Dict
from lightllm.common.triton_utils.autotuner import autotune


Expand Down Expand Up @@ -77,9 +75,12 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict =
assert output.shape[0] == token_num and output.shape[1] == hidden_dim

if not run_config:
run_config = MoeSumReduceKernelConfig.try_to_get_best_config(
M=token_num, topk_num=topk_num, hidden_dim=hidden_dim, out_dtype=str(output.dtype)
)
run_config = {
"BLOCK_M": 1,
"BLOCK_DIM": 128,
"NUM_STAGE": 1,
"num_warps": 2,
}

BLOCK_M = run_config["BLOCK_M"]
BLOCK_DIM = run_config["BLOCK_DIM"]
Expand Down
Loading
Loading