From 4523ef6f562884f09b03a7f165f486fa0e2a5522 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Thu, 19 Mar 2026 07:52:35 +0000
Subject: [PATCH] fix _block_scaled_block_gemm kernel, remove config class

---
 .../quantization/fp8w8a8_block_gemm_kernel.py | 72 +++----------------
 1 file changed, 9 insertions(+), 63 deletions(-)

diff --git a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py
index 91341009f0..9e18991fea 100644
--- a/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py
+++ b/lightllm/common/basemodel/triton_kernel/quantization/fp8w8a8_block_gemm_kernel.py
@@ -1,67 +1,8 @@
 import torch
 import triton
 import triton.language as tl
-
-from lightllm.common.kernel_config import KernelConfigs
-from frozendict import frozendict
-from functools import lru_cache
-from typing import Any, Dict, List, Optional, Tuple
-from triton import Config
 from lightllm.common.triton_utils.autotuner import autotune
-
-
-class Fp8BlockMMKernelConfig(KernelConfigs):
-    kernel_name: str = "fp8_block_mm"
-
-    @classmethod
-    @lru_cache(maxsize=200)
-    def try_to_get_best_config(
-        cls,
-        M: int,
-        N: int,
-        K: int,
-        block_size: Tuple[int, int],
-        out_dtype: str,
-    ) -> dict:
-        key_params = {
-            "N": N,
-            "K": K,
-            "block_size": block_size,
-            "out_dtype": str(out_dtype),
-        }
-        key_params = frozendict(key_params)
-
-        finded_config = cls.get_the_config(key_params)
-
-        if finded_config:
-            # find by M
-            config: dict = finded_config[min(finded_config.keys(), key=lambda x: abs(int(x) - M))]
-            return config
-        else:
-            config = {
-                "BLOCK_M": 64,
-                "BLOCK_N": block_size[0],
-                "BLOCK_K": block_size[1],
-                "GROUP_M": 32,
-                "num_warps": 4,
-                "num_stages": 3,
-            }
-            return config
-
-    @classmethod
-    def save_config(
-        cls, N: int, K: int, block_size: Tuple[int, int], out_dtype: str, config_json: Dict[int, Dict[int, Dict]]
-    ):
-
-        key_params = {
-            "N": N,
-            "K": K,
-            "block_size": block_size,
-            "out_dtype": str(out_dtype),
-        }
-        key_params = frozendict(key_params)
-
-        return cls.store_config(key_params, config_json)
+from typing import List
 
 
 @triton.jit
@@ -215,9 +156,14 @@ def w8a8_block_fp8_matmul(
     assert triton.cdiv(K, block_k) == Ascale.shape[-1] and Ascale.shape[-1] == Bscale.shape[0]
     assert triton.cdiv(N, block_n) == Bscale.shape[1]
     if not run_config:
-        run_config = Fp8BlockMMKernelConfig.try_to_get_best_config(
-            M=M, N=N, K=K, block_size=block_size, out_dtype=dtype
-        )
+        run_config = {
+            "BLOCK_M": 64,
+            "BLOCK_N": block_size[0],
+            "BLOCK_K": block_size[1],
+            "GROUP_M": 32,
+            "num_warps": 4,
+            "num_stages": 3,
+        }
     grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),)
     _block_scaled_block_gemm[grid](
         A,