diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 97075e9123..638abbd6ca 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -25,7 +25,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.vllm_utils import vllm_ops from lightllm.utils.device_utils import triton_support_tensor_descriptor -from .moe_kernel_configs import MoeGroupedGemmKernelConfig from .moe_silu_and_mul import silu_and_mul_fwd from .moe_sum_reduce import moe_sum_reduce from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import per_token_group_quant_fp8 @@ -726,16 +725,26 @@ def grouped_matmul( block_size_k = expert_weights.shape[2] // expert_to_weights_scale.shape[2] if run_config is None: - run_config = MoeGroupedGemmKernelConfig.try_to_get_best_config( - M=token_inputs.shape[0], - N=n, - K=k, - topk_num=topk_num, - expert_num=expert_num, - mul_routed_weight=mul_routed_weight, - use_fp8_w8a8=use_fp8_w8a8, - out_dtype=str(out.dtype), - ) + if token_inputs.shape[0] <= expert_num: + run_config = { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": False, + "num_warps": 4, + "num_stages": 1, + } + else: + run_config = { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 8, + "NEED_TRANS": False, + "num_warps": 4, + "num_stages": 1, + } BLOCK_SIZE_M = run_config["BLOCK_SIZE_M"] BLOCK_SIZE_N = run_config["BLOCK_SIZE_N"] diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py deleted file mode 100644 index ab6620cee7..0000000000 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_kernel_configs.py +++ /dev/null @@ -1,88 +0,0 @@ -import os -from frozendict import frozendict -from functools import lru_cache -from lightllm.common.kernel_config import KernelConfigs -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class MoeGroupedGemmKernelConfig(KernelConfigs): - kernel_name: str = "grouped_moe_gemm_kernel" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - M: int, - N: int, - K: int, - topk_num: int, - expert_num: int, - mul_routed_weight: bool, - use_fp8_w8a8: bool, - out_dtype: str, - ) -> dict: - key_params = { - "N": N, - "K": K, - "topk_num": topk_num, - "expert_num": expert_num, - "mul_routed_weight": mul_routed_weight, - "use_fp8_w8a8": use_fp8_w8a8, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - config = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] - return config - else: - if M <= expert_num: - config = { - "BLOCK_SIZE_M": 16, - "BLOCK_SIZE_N": 32, - "BLOCK_SIZE_K": 64, - "GROUP_SIZE_M": 1, - "NEED_TRANS": False, - "num_warps": 4, - "num_stages": 1, - } - else: - config = { - "BLOCK_SIZE_M": 64, - "BLOCK_SIZE_N": 64, - "BLOCK_SIZE_K": 32, - "GROUP_SIZE_M": 8, - "NEED_TRANS": False, - "num_warps": 4, - "num_stages": 1, - } - return config - - @classmethod - def save_config( - cls, - N: int, - K: int, - topk_num: int, - expert_num: int, - mul_routed_weight: bool, - use_fp8_w8a8: bool, - out_dtype: str, - config_json: dict, - ): - key_params = { - "N": N, - "K": K, - "topk_num": topk_num, - "expert_num": expert_num, - "mul_routed_weight": mul_routed_weight, - "use_fp8_w8a8": use_fp8_w8a8, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index 6f6617b556..d7bcc17743 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -2,7 +2,6 @@ import triton import triton.language as tl -from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig from lightllm.common.triton_utils.autotuner import autotune @@ -121,7 +120,10 @@ def silu_and_mul_fwd( size_n = input.shape[-1] // 2 if not run_config: - run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype)) + if size_m < 256: + run_config = {"BLOCK_M": 1, "BLOCK_N": 128, "num_warps": 1, "NUM_STAGES": 1} + else: + run_config = {"BLOCK_M": 16, "BLOCK_N": 128, "num_warps": 4, "NUM_STAGES": 5} BLOCK_M = run_config["BLOCK_M"] BLOCK_N = run_config["BLOCK_N"] diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py deleted file mode 100644 index 173101b898..0000000000 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_config.py +++ /dev/null @@ -1,53 +0,0 @@ -import os -from frozendict import frozendict -from functools import lru_cache -from lightllm.common.kernel_config import KernelConfigs -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class MoeSiluAndMulKernelConfig(KernelConfigs): - kernel_name: str = "moe_silu_and_mul_kernel" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - M: int, - N: int, - out_dtype: str, - ) -> dict: - key_params = { - "N": N, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - config = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] - return config - else: - if M < 256: - config = {"BLOCK_M": 1, "BLOCK_N": 128, "num_warps": 1, "NUM_STAGES": 1} - else: - config = {"BLOCK_M": 16, "BLOCK_N": 128, "num_warps": 4, "NUM_STAGES": 5} - - return config - - @classmethod - def save_config( - cls, - N: int, - out_dtype: str, - config_json: dict, - ): - key_params = { - "N": N, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py index be3f019ba5..d2c44b2953 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py @@ -2,7 +2,6 @@ import triton import triton.language as tl -from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig @triton.jit diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py deleted file mode 100644 index d8fce5ed93..0000000000 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_recude_config.py +++ /dev/null @@ -1,59 +0,0 @@ -import os -from frozendict import frozendict -from functools import lru_cache -from lightllm.common.kernel_config import KernelConfigs -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class MoeSumReduceKernelConfig(KernelConfigs): - kernel_name: str = "moe_sum_reduce_kernel" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - M: int, - topk_num: int, - hidden_dim: int, - out_dtype: str, - ) -> dict: - key_params = { - "topk_num": topk_num, - "hidden_dim": hidden_dim, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - config = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] - return config - else: - config = { - "BLOCK_M": 1, - "BLOCK_DIM": 128, - "NUM_STAGE": 1, - "num_warps": 2, - } - - return config - - @classmethod - def save_config( - cls, - topk_num: int, - hidden_dim: int, - out_dtype: str, - config_json: dict, - ): - key_params = { - "topk_num": topk_num, - "hidden_dim": hidden_dim, - "out_dtype": str(out_dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py index 266823b19c..e16351eec8 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_sum_reduce.py @@ -1,9 +1,7 @@ import torch - import triton import triton.language as tl -from .moe_sum_recude_config import MoeSumReduceKernelConfig -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Dict from lightllm.common.triton_utils.autotuner import autotune @@ -77,9 +75,12 @@ def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = assert output.shape[0] == token_num and output.shape[1] == hidden_dim if not run_config: - run_config = MoeSumReduceKernelConfig.try_to_get_best_config( - M=token_num, topk_num=topk_num, hidden_dim=hidden_dim, out_dtype=str(output.dtype) - ) + run_config = { + "BLOCK_M": 1, + "BLOCK_DIM": 128, + "NUM_STAGE": 1, + "num_warps": 2, + } BLOCK_M = run_config["BLOCK_M"] BLOCK_DIM = run_config["BLOCK_DIM"] diff --git a/lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py b/lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py index 030937bb1b..2b59702a42 100644 --- a/lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py +++ b/lightllm/common/basemodel/triton_kernel/quantization/bmm_scaled_fp8.py @@ -2,66 +2,45 @@ import torch.nn.functional as F import triton import triton.language as tl - -from lightllm.common.kernel_config import KernelConfigs -from frozendict import frozendict -from functools import lru_cache -from typing import Dict - - -class BmmScaledFp8KernelConfig(KernelConfigs): - kernel_name: str = "bmm_scaled_fp8" - - def closest_power_2(n: int) -> int: - return 1 << (n - 1).bit_length() if n & (n - 1) else n - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - B, - M, - N, - K, - batch_size, - head_dim, - ) -> dict: - key_params = { - "B": B, - "M": M, - "N": N, - "K": K, - "out_dtype": str(torch.bfloat16), - } - finded_config = cls.get_the_config(frozendict(key_params)) - - search_keys = [batch_size, head_dim] - if finded_config: - config = finded_config - for key in search_keys: - config = config[min(config.keys(), key=lambda x: abs(int(x) - key))] - else: - config = { - "BLOCK_SIZE_M": 128, - "BLOCK_SIZE_N": 256, - "BLOCK_SIZE_K": 128, - "GROUP_SIZE_M": 8, - "num_stages": 4, - "num_warps": 8, - } - return config - - @classmethod - def save_config(cls, B, M, N, K, config_json: Dict[int, Dict[int, Dict]]): - key_params = { - "B": B, - "M": M, - "N": N, - "K": K, - "out_dtype": str(torch.bfloat16), - } - key_params = frozendict(key_params) - return cls.store_config(key_params, config_json) +from lightllm.common.triton_utils.autotuner import autotune + + +def get_test_configs(): + """Generate test configurations for autotuning.""" + configs = [] + for block_size_m in [64, 128, 256]: + for block_size_n in [64, 128, 256]: + for block_size_k in [64, 128]: + for group_size_m in [4, 8, 16]: + for num_warps in [4, 8]: + for num_stages in [2, 3, 4]: + t_config = { + "BLOCK_SIZE_M": block_size_m, + "BLOCK_SIZE_N": block_size_n, + "BLOCK_SIZE_K": block_size_k, + "GROUP_SIZE_M": group_size_m, + "num_stages": num_stages, + "num_warps": num_warps, + } + configs.append(t_config) + return configs + + +def _get_static_key(a, b, c): + """Returns static key for caching (shape parameters that don't change during run).""" + B, M, K = a.shape + _, K, N = b.shape + return { + "B": B, + "N": N, + "K": K, + "out_dtype": str(c.dtype), + } + + +def _get_run_key(a): + """Returns run-time key for indexing configs (the varying dimension).""" + return a.shape[1] # M dimension @triton.jit @@ -142,7 +121,28 @@ def bmm_scaled_fp8_kernel( tl.store(c_ptrs, c, mask=c_mask) -def bmm_scaled_fp8(a, a_scale, b, b_scale, c, **run_config): +@autotune( + kernel_name="bmm_scaled_fp8:v1", + configs_gen_func=get_test_configs, + static_key_func=_get_static_key, + run_key_func=_get_run_key, + mutates_args=["c"], +) +def bmm_scaled_fp8(a, a_scale, b, b_scale, c, run_config=None): + """Batched matrix multiplication with FP8 scaling. + + Args: + a: Input tensor A with shape [batch, M, K] in FP8 format. + a_scale: Scale tensor for A with shape [batch, M, 1]. + b: Input tensor B with shape [batch, K, N] in FP8 format. + b_scale: Scale tensor for B with shape [batch, N, 1]. + c: Output tensor with shape [batch, M, N]. + run_config: Optional config dict with BLOCK_SIZE_M, BLOCK_SIZE_N, + BLOCK_SIZE_K, GROUP_SIZE_M, num_stages, num_warps. + + Returns: + torch.Tensor: The output tensor c. + """ assert a.shape[0] == b.shape[0], "Incompatible dimensions" assert c.shape[0] == b.shape[0], "Incompatible dimensions" assert a.shape[2] == b.shape[1], "Incompatible dimensions" @@ -152,15 +152,14 @@ def bmm_scaled_fp8(a, a_scale, b, b_scale, c, **run_config): HEAD = a.shape[0] if not run_config: - M2 = BmmScaledFp8KernelConfig.closest_power_2(M) - run_config = BmmScaledFp8KernelConfig.try_to_get_best_config( - B=HEAD, - M=M2, - N=N, - K=K, - batch_size=M2, - head_dim=N, - ) + run_config = { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 8, + "num_stages": 4, + "num_warps": 8, + } grid = lambda META: ( triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py index d3646ac7bb..30e5a59248 100644 --- a/lightllm/models/deepseek2/triton_kernel/rotary_emb.py +++ b/lightllm/models/deepseek2/triton_kernel/rotary_emb.py @@ -134,16 +134,11 @@ def rotary_emb_fwd(q, k, cos, sin, run_config=None): assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" assert triton.next_power_of_2(head_dim) == head_dim - from .rotary_emb_config import DeepseekV3RotaryKernelConfig - if not run_config: - run_config = DeepseekV3RotaryKernelConfig.try_to_get_best_config( - M=total_len, - Q_HEAD_NUM=head_num_q, - K_HEAD_NUM=head_num_k, - HEAD_DIM=head_dim, - dtype=str(q.dtype), - ) + if total_len <= 256: + run_config = {"BLOCK_SEQ": 1, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1} + else: + run_config = {"BLOCK_SEQ": 16, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1} BLOCK_SEQ = run_config["BLOCK_SEQ"] HEAD_PARALLEL_NUM = run_config["HEAD_PARALLEL_NUM"] diff --git a/lightllm/models/deepseek2/triton_kernel/rotary_emb_config.py b/lightllm/models/deepseek2/triton_kernel/rotary_emb_config.py deleted file mode 100644 index 9ea5825957..0000000000 --- a/lightllm/models/deepseek2/triton_kernel/rotary_emb_config.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -from frozendict import frozendict -from functools import lru_cache -from lightllm.common.kernel_config import KernelConfigs -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class DeepseekV3RotaryKernelConfig(KernelConfigs): - kernel_name: str = "deepseek_v3_rotary_emb_kernel" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - M: int, - Q_HEAD_NUM: int, - K_HEAD_NUM: int, - HEAD_DIM: int, - dtype: str, - ) -> dict: - key_params = { - "Q_HEAD_NUM": Q_HEAD_NUM, - "K_HEAD_NUM": K_HEAD_NUM, - "HEAD_DIM": HEAD_DIM, - "dtype": str(dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - config = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] - return config - else: - if M <= 256: - config = {"BLOCK_SEQ": 1, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1} - else: - config = {"BLOCK_SEQ": 16, "NUM_STAGE": 1, "num_warps": 1, "num_stages": 1, "HEAD_PARALLEL_NUM": 1} - - return config - - @classmethod - def save_config( - cls, - Q_HEAD_NUM: int, - K_HEAD_NUM: int, - HEAD_DIM: int, - dtype: str, - config_json: dict, - ): - key_params = { - "Q_HEAD_NUM": Q_HEAD_NUM, - "K_HEAD_NUM": K_HEAD_NUM, - "HEAD_DIM": HEAD_DIM, - "dtype": str(dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) diff --git a/lightllm/models/qwen2_vl/triton_kernel/mrope.py b/lightllm/models/qwen2_vl/triton_kernel/mrope.py index 5aed658626..86df522ef7 100644 --- a/lightllm/models/qwen2_vl/triton_kernel/mrope.py +++ b/lightllm/models/qwen2_vl/triton_kernel/mrope.py @@ -1,72 +1,14 @@ -import time import torch import itertools import triton import triton.language as tl from typing import Optional -from frozendict import frozendict -from functools import lru_cache -from lightllm.common.kernel_config import KernelConfigs from lightllm.utils.log_utils import init_logger from lightllm.common.triton_utils.autotuner import autotune logger = init_logger(__name__) -class MropeTritonFusedKernelConfig(KernelConfigs): - kernel_name: str = "mrope_triton_fused_kernel" - - @classmethod - @lru_cache(maxsize=200) - def try_to_get_best_config( - cls, - M: int, - Q_HEAD_NUM: int, - K_HEAD_NUM: int, - HEAD_DIM: int, - dtype: str, - ) -> dict: - key_params = { - "Q_HEAD_NUM": Q_HEAD_NUM, - "K_HEAD_NUM": K_HEAD_NUM, - "HEAD_DIM": HEAD_DIM, - "dtype": str(dtype), - } - key_params = frozendict(key_params) - - finded_config = cls.get_the_config(key_params) - - if finded_config: - config = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))] - return config - else: - if M <= 256: - config = {"num_warps": 1, "num_stages": 1} - else: - config = {"num_warps": 1, "num_stages": 1} - - return config - - @classmethod - def save_config( - cls, - Q_HEAD_NUM: int, - K_HEAD_NUM: int, - HEAD_DIM: int, - dtype: str, - config_json: dict, - ): - key_params = { - "Q_HEAD_NUM": Q_HEAD_NUM, - "K_HEAD_NUM": K_HEAD_NUM, - "HEAD_DIM": HEAD_DIM, - "dtype": str(dtype), - } - key_params = frozendict(key_params) - - return cls.store_config(key_params, config_json) - - @triton.jit def _mrope_triton_fused_kernel( q, @@ -200,13 +142,8 @@ def mrope_triton_fused( num_tokens = q.shape[0] if not run_config: - run_config = MropeTritonFusedKernelConfig.try_to_get_best_config( - M=num_tokens, - Q_HEAD_NUM=head_num_q, - K_HEAD_NUM=head_num_k, - HEAD_DIM=head_dim, - dtype=str(q.dtype), - ) + run_config = {"num_warps": 1, "num_stages": 1} + num_stages = run_config["num_stages"] num_warps = run_config["num_warps"] diff --git a/test/kernel/deepseekv2_bmm_scaled_fp8_tuning.py b/test/kernel/deepseekv2_bmm_scaled_fp8_tuning.py deleted file mode 100644 index d029c427bb..0000000000 --- a/test/kernel/deepseekv2_bmm_scaled_fp8_tuning.py +++ /dev/null @@ -1,94 +0,0 @@ -import triton -import torch -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -from lightllm.utils.tuning_utils import mp_tuning, set_seed, tuning_configs -import sys -import os - -from lightllm.common.basemodel.triton_kernel.bmm_scaled_fp8 import bmm_scaled_fp8, BmmScaledFp8KernelConfig - - -@torch.no_grad() -def test_func( - B, - M, - N, - K, - dtype, - test_count: int = 20, - **run_config, -): - set_seed() - - a_scale = torch.randn([B, M, 1], device="cuda", dtype=dtype) - b_scale = torch.randn([B, N, 1], device="cuda", dtype=dtype) - a = torch.randn([B, M, K], device="cuda", dtype=dtype) - b = torch.randn([B, K, N], device="cuda", dtype=dtype) - c = torch.zeros([B, M, N], device="cuda", dtype=dtype) - a = a.to(torch.float8_e4m3fn) - b = b.to(torch.float8_e4m3fn).transpose(1, 2).contiguous().transpose(1, 2) - fn = lambda: bmm_scaled_fp8(a, a_scale, b, b_scale, c, **run_config) - cost_time = triton.testing.do_bench_cudagraph(fn, rep=test_count) - - logger.info(f"bf16 {B, M, N, K} cost time: {cost_time} ms") - return cost_time - - -def get_test_configs(split_id, split_count, **kwargs): - index = 0 - for block_size_m in [64, 128, 256]: - for block_size_n in [64, 128, 256]: - for block_size_k in [64, 128]: - for group_size_m in [4, 8, 16]: - for num_warps in [4, 8]: - for num_stages in [2, 3, 4]: - t_config = { - "BLOCK_SIZE_M": block_size_m, - "BLOCK_SIZE_N": block_size_n, - "BLOCK_SIZE_K": block_size_k, - "GROUP_SIZE_M": group_size_m, - "num_stages": num_stages, - "num_warps": num_warps, - } - if index % split_count == split_id: - yield t_config - index += 1 - - -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") - import collections - - store_json_ans = collections.defaultdict(dict) - for batch_size in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512]: - for head_dim in [128, 512]: - k = 128 if head_dim == 512 else 512 - test_func_args = { - "B": 16, - "M": batch_size, - "N": head_dim, - "K": k, - "dtype": torch.bfloat16, - "test_count": 20, - } - ans = mp_tuning( - tuning_configs, - { - "test_func": test_func, - "test_func_args": test_func_args, - "get_test_configs_func": get_test_configs, - }, - ) - store_json_ans[batch_size][head_dim] = ans - BmmScaledFp8KernelConfig.save_config( - B=16, - M=batch_size, - N=head_dim, - K=k, - config_json=store_json_ans, - ) - - pass diff --git a/test/kernel/fuse_moe_tuning.py b/test/kernel/fuse_moe_tuning.py deleted file mode 100644 index 69fc4e0c80..0000000000 --- a/test/kernel/fuse_moe_tuning.py +++ /dev/null @@ -1,501 +0,0 @@ -import os -import argparse -import torch -import time -import torch.multiprocessing as mp -from lightllm.common.fused_moe.grouped_fused_moe import fused_experts_impl, moe_align, moe_align1, grouped_matmul -from typing import List -from lightllm.utils.log_utils import init_logger -from transformers import AutoConfig -import torch.nn.functional as F - -logger = init_logger(__name__) - - -def set_seed(): - import torch - import random - import numpy as np - - seed = 42 - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - return - - -def quantize_moe(weight): - try: - HAS_VLLM = True - from lightllm.common.vllm_kernel import _custom_ops as ops - except: - HAS_VLLM = False - - assert HAS_VLLM - - num_experts = weight.shape[0] - qweights = [] - weight_scales = [] - qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda() - for i in range(num_experts): - qweight, weight_scale = ops.scaled_fp8_quant( - weight[i].contiguous().cuda(), scale=None, use_per_token_if_dynamic=False - ) - qweights[i] = qweight - weight_scales.append(weight_scale) - weight_scale = torch.cat(weight_scales, dim=0).reshape(-1) - return qweights, weight_scale - - -@torch.no_grad() -def test_kernel( - expert_num: int, - m: int, - n: int, - k: int, - topk: int, - dtype: torch.dtype, - test_count: int, - use_fp8_w8a8: bool, - is_up: bool, - block_shape, - num_fused_shared_experts: int, - **config, -): - set_seed() - input_tuples = [] - - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1_scale = w2_scale = None - if num_fused_shared_experts > 0: - expert_num += num_fused_shared_experts - - if use_fp8_w8a8: - init_dtype = dtype - w1 = torch.randn(expert_num, 2 * n, k, dtype=init_dtype).cuda() - w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=init_dtype).cuda() - w1 = w1.to(torch.float8_e4m3fn) - w2 = w2.to(torch.float8_e4m3fn) - - if block_shape is None: - w1_scale = torch.randn(expert_num, dtype=torch.float32).cuda() - w2_scale = torch.randn(expert_num, dtype=torch.float32).cuda() - else: - block_n, block_k = block_shape[0], block_shape[1] - n_tiles_w1 = (2 * n + block_n - 1) // block_n - n_tiles_w2 = (k + block_n - 1) // block_n - k_tiles_w1 = (k + block_k - 1) // block_k - k_tiles_w2 = (2 * n // 2 + block_k - 1) // block_k - w1_scale = torch.rand((expert_num, n_tiles_w1, k_tiles_w1), dtype=torch.float32).cuda() - w2_scale = torch.rand((expert_num, n_tiles_w2, k_tiles_w2), dtype=torch.float32).cuda() - else: - w1 = torch.randn(expert_num, 2 * n, k, dtype=dtype).cuda() - w2 = torch.randn(expert_num, k, 2 * n // 2, dtype=dtype).cuda() - - rnd_logics = torch.randn(m, expert_num - num_fused_shared_experts, device="cuda") - topk_values, topk_ids = torch.topk(rnd_logics, topk, dim=1) - if num_fused_shared_experts > 0: - # 存在融合共享专家的时候,需要pad 共享专家对应的id 到topk_ids 中 - pad_topk_ids = ( - torch.arange( - start=expert_num - num_fused_shared_experts, end=expert_num, step=1, dtype=topk_ids.dtype, device="cuda" - ) - .view(1, num_fused_shared_experts) - .repeat(topk_ids.shape[0], 1) - ) - topk_ids = torch.cat([topk_ids, pad_topk_ids], dim=1) - topk_weights = torch.randn((m, topk + num_fused_shared_experts), device="cuda", dtype=dtype) / 10 - - expert_to_tokens = torch.empty( - (expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.int32, device="cuda" - ) - expert_to_weights = torch.empty( - (expert_num, (topk + num_fused_shared_experts) * m), dtype=torch.float32, device="cuda" - ) - moe_align(topk_ids=topk_ids, out=expert_to_tokens) - expert_to_token_num = torch.empty((expert_num,), dtype=torch.int32, device="cuda") - moe_align1( - expert_to_tokens, topk_weights, expert_to_weights, expert_to_token_num, topk=topk + num_fused_shared_experts - ) - - out1 = torch.zeros((m * (topk + num_fused_shared_experts), 2 * n), dtype=torch.bfloat16, device="cuda") - down_in = torch.zeros((m * (topk + num_fused_shared_experts), n), dtype=torch.bfloat16, device="cuda") - out2 = torch.zeros((m * (topk + num_fused_shared_experts), k), dtype=torch.bfloat16, device="cuda") - - for _ in range(test_count): - input_tuples.append( - ( - a.clone(), - w1.clone(), - w2.clone(), - w1_scale.clone() if w1_scale is not None else None, - w2_scale.clone() if w2_scale is not None else None, - topk_ids.clone(), - topk_weights.clone(), - out1.clone(), - out2.clone(), - down_in.clone(), - ) - ) - - if is_up: - grouped_matmul( - topk_ids.numel(), - a, - None, - expert_to_token_num, - expert_to_tokens, - expert_to_weights=expert_to_weights, - expert_weights=w1, - expert_to_weights_scale=w1_scale, - topk_num=topk, - out=out1, - mul_routed_weight=False, - use_fp8_w8a8=use_fp8_w8a8, - run_config=config, - ) - else: - grouped_matmul( - topk_ids.numel(), - down_in, - None, - expert_to_token_num, - expert_to_tokens, - expert_to_weights=expert_to_weights, - expert_weights=w2, - expert_to_weights_scale=w2_scale, - topk_num=1, - out=out2, - mul_routed_weight=True, - use_fp8_w8a8=use_fp8_w8a8, - run_config=config, - ) - - graph = torch.cuda.CUDAGraph() - - with torch.cuda.graph(graph): - for index in range(test_count): - a, w1, w2, w1_scale, w2_scale, topk_ids, topk_weights, out1, out2, down_in = input_tuples[index] - if is_up: - grouped_matmul( - topk_ids.numel(), - a, - None, - expert_to_token_num, - expert_to_tokens, - expert_to_weights=expert_to_weights, - expert_weights=w1, - expert_to_weights_scale=w1_scale, - topk_num=topk, - out=out1, - mul_routed_weight=False, - use_fp8_w8a8=use_fp8_w8a8, - run_config=config, - ) - else: - grouped_matmul( - topk_ids.numel(), - down_in, - None, - expert_to_token_num, - expert_to_tokens, - expert_to_weights=expert_to_weights, - expert_weights=w2, - expert_to_weights_scale=w2_scale, - topk_num=1, - out=out2, - mul_routed_weight=True, - use_fp8_w8a8=use_fp8_w8a8, - run_config=config, - ) - - graph.replay() - - torch.cuda.synchronize() - start = time.time() - graph.replay() - torch.cuda.synchronize() - - cost_time = (time.time() - start) * 1000 - - logger.info(str(config)) - logger.info(f"bf16 {m} cost time: {cost_time} ms") - return cost_time - - -def worker( - expert_num: int, - m: int, - n: int, - k: int, - topk: int, - dtype: torch.dtype, - test_count: int, - use_fp8_w8a8: bool, - is_up: bool, - block_shape, - num_fused_shared_experts: int, - test_configs, - queue, -): - try: - for index in range(len(test_configs)): - cost_time = test_kernel( - expert_num=expert_num, - m=m, - n=n, - k=k, - topk=topk, - dtype=dtype, - test_count=test_count, - use_fp8_w8a8=use_fp8_w8a8, - is_up=is_up, - block_shape=block_shape, - num_fused_shared_experts=num_fused_shared_experts, - **test_configs[index], - ) - queue.put(cost_time) # Put result in queue - - except Exception as ex: - logger.error(str(ex)) - logger.exception(str(ex)) - import sys - - sys.exit(-1) - pass - - -def get_test_configs(split_id, split_count): - index = 0 - for num_stages in [ - 1, - 2, - 3, - 4, - 5, - ]: - for GROUP_SIZE_M in [ - 1, - 2, - 4, - ]: - for num_warps in [ - 2, - 4, - 8, - ]: - for BLOCK_SIZE_M in [16, 32, 64, 128]: - for BLOCK_SIZE_N in [32, 64, 128]: - for BLOCK_SIZE_K in [32, 64, 128]: - t_config = { - "BLOCK_SIZE_M": BLOCK_SIZE_M, - "BLOCK_SIZE_N": BLOCK_SIZE_N, - "BLOCK_SIZE_K": BLOCK_SIZE_K, - "GROUP_SIZE_M": GROUP_SIZE_M, - "num_warps": num_warps, - "num_stages": num_stages, - } - if index % split_count == split_id: - yield t_config - index += 1 - else: - index += 1 - - -def tuning_configs( - device_id: int, # use for mult mp tunning - device_count: int, - expert_num: int, - m: int, - n: int, - k: int, - topk: int, - dtype: torch.dtype, - test_count: int, - use_fp8_w8a8: bool, - is_up: bool, - block_shape, - num_fused_shared_experts: int, -): - os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) - best_config, best_cost_time = None, 10000000 - queue = mp.Queue() - test_configs = [] - for t_config in get_test_configs(device_id, device_count): - test_configs.append(t_config) - if len(test_configs) < 256: - continue - - p = mp.Process( - target=worker, - args=( - expert_num, - m, - n, - k, - topk, - dtype, - test_count, - use_fp8_w8a8, - is_up, - block_shape, - num_fused_shared_experts, - test_configs, - queue, - ), - ) - p.start() - p.join() - while len(test_configs) != 0: - try: - cost_time = queue.get_nowait() - logger.info(f"get {test_configs[0]} cost_time: {cost_time}") - if cost_time < best_cost_time: - best_config = test_configs[0] - best_cost_time = cost_time - logger.info(f"cur best : {best_config} {best_cost_time}") - del test_configs[0:1] - except: - del test_configs[0:16] - logger.info(f"cur best : {best_config} {best_cost_time}") - break - - while len(test_configs) != 0: - p = mp.Process( - target=worker, - args=( - expert_num, - m, - n, - k, - topk, - dtype, - test_count, - use_fp8_w8a8, - is_up, - block_shape, - num_fused_shared_experts, - test_configs, - queue, - ), - ) - p.start() - p.join() - - while len(test_configs) != 0: - try: - cost_time = queue.get_nowait() - logger.info(f"get {test_configs[0]} cost_time: {cost_time}") - if cost_time < best_cost_time: - best_config = test_configs[0] - best_cost_time = cost_time - logger.info(f"cur best : {best_config} {best_cost_time}") - del test_configs[0:1] - except: - del test_configs[0:16] - logger.info(f"cur best : {best_config} {best_cost_time}") - break - - logger.info(f"{best_config} best cost: {best_cost_time}") - return best_config, best_cost_time - - -def main(args): - torch.multiprocessing.set_start_method("spawn") - from lightllm.utils.tuning_utils import mp_tuning - from lightllm.common.fused_moe.moe_kernel_configs import MoeGroupedGemmKernelConfig - - config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True) - if config.architectures[0] == "Qwen3MoeForCausalLM": - expert_num = config.num_experts - topk_num = config.num_experts_per_tok - n = config.moe_intermediate_size // args.tp - elif config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"]: - expert_num = config.n_routed_experts - topk_num = config.num_experts_per_tok - n = config.moe_intermediate_size // args.tp - else: - pass - - hidden_dim = getattr(config, "hidden_size", None) or config.text_config.hidden_size - print(n, hidden_dim) - use_fp8_w8a8 = args.use_fp8_w8a8 - block_shape = None - if hasattr(config, "quantization_config") and "weight_block_size" in config.quantization_config: - block_shape = config.quantization_config["weight_block_size"] - assert len(block_shape) == 2 - use_fp8_w8a8 = True - - up_dict = {} - for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192]: - ans = mp_tuning( - tuning_configs, - { - "expert_num": expert_num, - "m": m, - "n": n, - "k": hidden_dim, - "topk": topk_num, - "dtype": torch.bfloat16, - "test_count": 20, - "use_fp8_w8a8": use_fp8_w8a8, - "is_up": True, - "block_shape": block_shape, - "num_fused_shared_experts": args.num_fused_shared_experts, - }, - ) - up_dict[m] = ans - MoeGroupedGemmKernelConfig.save_config( - N=n * 2, - K=hidden_dim, - topk_num=topk_num, - expert_num=expert_num + args.num_fused_shared_experts, - mul_routed_weight=False, - use_fp8_w8a8=use_fp8_w8a8, - out_dtype=str(torch.bfloat16), - config_json=up_dict, - ) - - down_dict = {} - for m in [1, 8, 64, 128, 256, 512, 1024, 4096, 8192, 16384, 32768]: - ans = mp_tuning( - tuning_configs, - { - "expert_num": expert_num, - "m": m, - "n": n, - "k": hidden_dim, - "topk": topk_num, - "dtype": torch.bfloat16, - "test_count": 20, - "use_fp8_w8a8": use_fp8_w8a8, - "is_up": False, - "block_shape": block_shape, - "num_fused_shared_experts": args.num_fused_shared_experts, - }, - ) - down_dict[m] = ans - - MoeGroupedGemmKernelConfig.save_config( - N=hidden_dim, - K=n, - topk_num=1, - expert_num=expert_num + args.num_fused_shared_experts, - mul_routed_weight=True, - use_fp8_w8a8=use_fp8_w8a8, - out_dtype=str(torch.bfloat16), - config_json=down_dict, - ) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--model_dir", type=str, default="deepseek-ai/DeepSeek-R1") - parser.add_argument("--tp", type=int, default=8) - parser.add_argument("--use_fp8_w8a8", action="store_true") - parser.add_argument("--num_fused_shared_experts", type=int, default=0) - args = parser.parse_args() - main(args) diff --git a/test/kernel/moe_silu_and_mul_tuning_bf16.py b/test/kernel/moe_silu_and_mul_tuning_bf16.py deleted file mode 100644 index 038480c3e9..0000000000 --- a/test/kernel/moe_silu_and_mul_tuning_bf16.py +++ /dev/null @@ -1,217 +0,0 @@ -import os -import torch -import time -import torch.multiprocessing as mp -import itertools -from lightllm.common.fused_moe.moe_silu_and_mul import MoeSiluAndMulKernelConfig, silu_and_mul_fwd -from lightllm.utils.watchdog_utils import Watchdog -from typing import List -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -def set_seed(): - import torch - import random - import numpy as np - - seed = 42 - torch.manual_seed(seed) - random.seed(seed) - np.random.seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - return - - -@torch.no_grad() -def test_kernel( - m: int, - n: int, - dtype: torch.dtype, - test_count: int, - **config, -): - set_seed() - input_tuples = [] - - input = torch.randn((m, 2 * n), device="cuda", dtype=dtype) / 10 - output = torch.randn((m, n), device="cuda", dtype=dtype) - - for _ in range(test_count): - input_tuples.append((input.clone(), output.clone())) - - # warm_up - silu_and_mul_fwd(input, output, run_config=config) - - graph = torch.cuda.CUDAGraph() - - with torch.cuda.graph(graph): - for index in range(test_count): - input, output = input_tuples[index] - silu_and_mul_fwd(input, output, run_config=config) - - graph.replay() - - torch.cuda.synchronize() - start = time.time() - graph.replay() - torch.cuda.synchronize() - - cost_time = (time.time() - start) * 1000 - - logger.info(str(config)) - logger.info(f"bf16 {m} cost time: {cost_time} ms") - return cost_time - - -def worker( - m: int, - n: int, - dtype: torch.dtype, - test_count: int, - test_configs, - queue, -): - dog = Watchdog(timeout=10) - dog.start() - try: - for index in range(len(test_configs)): - cost_time = test_kernel( - m=m, - n=n, - dtype=dtype, - test_count=test_count, - **test_configs[index], - ) - dog.heartbeat() - queue.put(cost_time) # Put result in queue - - except Exception as ex: - logger.error(str(ex)) - logger.exception(str(ex)) - import sys - - sys.exit(-1) - pass - - -def get_test_configs(split_id, split_count): - index = 0 - result = itertools.product([1, 2, 4, 8, 16, 32], [64, 128, 256, 512, 1024], [1, 2, 4, 8, 16], [1, 2, 4, 8, 16]) - for BLOCK_M, BLOCK_N, num_warps, NUM_STAGES in result: - t_config = { - "BLOCK_M": BLOCK_M, - "BLOCK_N": BLOCK_N, - "num_warps": num_warps, - "NUM_STAGES": NUM_STAGES, - } - if index % split_count == split_id: - yield t_config - index += 1 - else: - index += 1 - - -def tuning_configs( - device_id: int, # use for mult mp tunning - device_count: int, - m: int, - n: int, - dtype: torch.dtype, - test_count: int, -): - os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id) - best_config, best_cost_time = None, 10000000 - queue = mp.Queue() - test_configs = [] - for t_config in get_test_configs(device_id, device_count): - test_configs.append(t_config) - if len(test_configs) < 256: - continue - - p = mp.Process( - target=worker, - args=( - m, - n, - dtype, - test_count, - test_configs, - queue, - ), - ) - p.start() - p.join() - while len(test_configs) != 0: - try: - cost_time = queue.get_nowait() - logger.info(f"get {test_configs[0]} cost_time: {cost_time}") - if cost_time < best_cost_time: - best_config = test_configs[0] - best_cost_time = cost_time - logger.info(f"cur best : {best_config} {best_cost_time}") - del test_configs[0:1] - except: - del test_configs[0:16] - logger.info(f"cur best : {best_config} {best_cost_time}") - break - - while len(test_configs) != 0: - p = mp.Process( - target=worker, - args=( - m, - n, - dtype, - test_count, - test_configs, - queue, - ), - ) - p.start() - p.join() - - while len(test_configs) != 0: - try: - cost_time = queue.get_nowait() - logger.info(f"get {test_configs[0]} cost_time: {cost_time}") - if cost_time < best_cost_time: - best_config = test_configs[0] - best_cost_time = cost_time - logger.info(f"cur best : {best_config} {best_cost_time}") - del test_configs[0:1] - except: - del test_configs[0:16] - logger.info(f"cur best : {best_config} {best_cost_time}") - break - - logger.info(f"{best_config} best cost: {best_cost_time}") - return best_config, best_cost_time - - -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") - from lightllm.utils.tuning_utils import mp_tuning - - # tuning to get silu and mul - for n in [128, 2304, 192, 256, 512, 1024, 1408, 2048, 4096, 8192]: - json_dict = {} - for m in [1, 8, 64, 128, 200, 256, 512, 1024, 2048, 4096, 8192]: - ans = mp_tuning( - tuning_configs, - { - "m": m, - "n": n, - "dtype": torch.bfloat16, - "test_count": 20, - }, - ) - json_dict[m] = ans - MoeSiluAndMulKernelConfig.save_config( - N=n, - out_dtype=str(torch.bfloat16), - config_json=json_dict, - )