Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,67 +1,8 @@
import torch
import triton
import triton.language as tl

from lightllm.common.kernel_config import KernelConfigs
from frozendict import frozendict
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple
from triton import Config
from lightllm.common.triton_utils.autotuner import autotune


class Fp8BlockMMKernelConfig(KernelConfigs):
    """Persisted tuning-config store for the fp8 block matmul kernel.

    Configs are stored/retrieved under a key of (N, K, block_size, out_dtype);
    within one key, entries are indexed by M so callers can pick the tuning
    closest to their actual batch dimension.
    """

    kernel_name: str = "fp8_block_mm"

    @classmethod
    @lru_cache(maxsize=200)
    def try_to_get_best_config(
        cls,
        M: int,
        N: int,
        K: int,
        block_size: Tuple[int, int],
        out_dtype: str,
    ) -> dict:
        """Return the stored config whose M key is nearest to *M*.

        Falls back to a conservative default config when nothing has been
        tuned/saved for this (N, K, block_size, out_dtype) combination.
        Results are memoized via lru_cache, keyed on the call arguments.
        """
        lookup_key = frozendict(
            {
                "N": N,
                "K": K,
                "block_size": block_size,
                "out_dtype": str(out_dtype),
            }
        )

        stored = cls.get_the_config(lookup_key)
        if not stored:
            # No tuned entry available: hand back a reasonable default.
            # NOTE(review): BLOCK_N/BLOCK_K mirror block_size[0]/block_size[1];
            # confirm this matches the (block_n, block_k) ordering the Triton
            # kernel expects — a PR review flagged a possible swap here.
            return {
                "BLOCK_M": 64,
                "BLOCK_N": block_size[0],
                "BLOCK_K": block_size[1],
                "GROUP_M": 32,
                "num_warps": 4,
                "num_stages": 3,
            }

        # Entries are keyed by M; choose the key numerically closest to M.
        nearest_m = min(stored.keys(), key=lambda m: abs(int(m) - M))
        return stored[nearest_m]

    @classmethod
    def save_config(
        cls, N: int, K: int, block_size: Tuple[int, int], out_dtype: str, config_json: Dict[int, Dict[int, Dict]]
    ):
        """Persist *config_json* under the same key layout used for lookup."""
        lookup_key = frozendict(
            {
                "N": N,
                "K": K,
                "block_size": block_size,
                "out_dtype": str(out_dtype),
            }
        )
        return cls.store_config(lookup_key, config_json)
from typing import List


@triton.jit
Expand Down Expand Up @@ -215,9 +156,14 @@ def w8a8_block_fp8_matmul(
assert triton.cdiv(K, block_k) == Ascale.shape[-1] and Ascale.shape[-1] == Bscale.shape[0]
assert triton.cdiv(N, block_n) == Bscale.shape[1]
if not run_config:
run_config = Fp8BlockMMKernelConfig.try_to_get_best_config(
M=M, N=N, K=K, block_size=block_size, out_dtype=dtype
)
run_config = {
"BLOCK_M": 64,
"BLOCK_N": block_size[0],
"BLOCK_K": block_size[1],
Comment on lines +161 to +162
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There seems to be a mix-up in assigning BLOCK_N and BLOCK_K from block_size. block_size[0] is block_k (for the K dimension) and block_size[1] is block_n (for the N dimension). However, BLOCK_N is being assigned block_size[0] and BLOCK_K is being assigned block_size[1]. This should be swapped to align the tiling dimensions correctly with the quantization block dimensions.

Suggested change
"BLOCK_N": block_size[0],
"BLOCK_K": block_size[1],
"BLOCK_N": block_size[1],
"BLOCK_K": block_size[0],

"GROUP_M": 32,
"num_warps": 4,
"num_stages": 3,
}
grid = (triton.cdiv(M, run_config["BLOCK_M"]) * triton.cdiv(N, run_config["BLOCK_N"]),)
_block_scaled_block_gemm[grid](
A,
Expand Down
Loading