Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions lightllm/common/basemodel/attention/create_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,13 @@
},
"int4kv": {
"triton": Int4kvTritonAttBackend,
"fa3": Fp8Fa3AttBackend,
"flashinfer": Fp8FlashInferAttBackend,
# "fa3": Fp8Fa3AttBackend,
# "flashinfer": Fp8FlashInferAttBackend,
},
"int8kv": {
"triton": Int8kvTritonAttBackend,
"fa3": Fp8Fa3AttBackend,
"flashinfer": Fp8FlashInferAttBackend,
# "fa3": Fp8Fa3AttBackend,
# "flashinfer": Fp8FlashInferAttBackend,
},
}

Expand Down Expand Up @@ -66,7 +66,7 @@ def _auto_select_backend(
backend_map = kv_type_to_backend

for backend_name in priority_list:
if validate(backend_name):
if backend_name in backend_map[llm_dtype] and validate(backend_name):
logger.info(f"Auto-selected {backend_name} backend (validated)")
return backend_map[llm_dtype][backend_name]

Expand Down
10 changes: 4 additions & 6 deletions lightllm/common/basemodel/attention/triton/int8kv.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def decode_att(
if enable_diverse_mode_gqa_decode_fast_kernel():
return self.diverse_decode_att(q=q, k=k, k_scale=k_scale, v=v, v_scale=v_scale, alloc_func=alloc_func)
else:
return self.ppl_mha_int8kv_decode_att(
return self.normal_decode_att(
q=q,
k=k,
k_scale=k_scale,
Expand Down Expand Up @@ -172,18 +172,16 @@ def diverse_decode_att(
alloc_tensor_func=alloc_func,
)

def ppl_mha_int8kv_decode_att(
def normal_decode_att(
self,
q: torch.Tensor,
k: torch.Tensor,
k_scale: torch.Tensor,
v: torch.Tensor,
v_scale: torch.Tensor,
alloc_func=torch.empty,
) -> torch.Tensor:
from ...triton_kernel.att.decode_att.int8kv.ppl_int8kv_flash_decoding import (
token_decode_attention_flash_decoding,
)
):
from ...triton_kernel.att.decode_att.int8kv.normal import token_decode_attention_flash_decoding

return token_decode_attention_flash_decoding(
q=q,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .int8kv_flash_decoding import token_decode_attention_flash_decoding
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
from typing import Optional
from lightllm.common.basemodel.infer_struct import InferStateInfo
from .int8kv_flash_decoding_stage1 import flash_decode_stage1
from .int8kv_flash_decoding_stage2 import flash_decode_stage2


@torch.no_grad()
def token_decode_attention_flash_decoding(
    q: torch.Tensor,
    infer_state: InferStateInfo,
    cache_k: torch.Tensor,
    cache_k_scale: torch.Tensor,
    cache_v: torch.Tensor,
    cache_v_scale: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    alloc_tensor_func=torch.empty,
):
    """Two-stage flash-decoding attention over an int8-quantized KV cache.

    Stage 1 computes per-sequence-block partial attention outputs plus their
    log-sum-exp terms; stage 2 reduces the partials into the final output.

    Args:
        q: query tensor; indexed here as (batch, q_head_num, head_dim).
        infer_state: inference state providing batch_size, request/sequence
            bookkeeping (req_to_token_indexs, b_req_idx, b_seq_len) and
            max_kv_seq_len.
        cache_k / cache_k_scale: int8 key cache and its dequant scales.
        cache_v / cache_v_scale: int8 value cache and its dequant scales.
        out: optional pre-allocated output tensor; allocated from q's shape,
            dtype and device when None.
        alloc_tensor_func: allocator used for the output and scratch tensors
            (defaults to torch.empty; callers may pass a pooled allocator).

    Returns:
        The attention output tensor (``out`` if provided, else freshly
        allocated), shaped like ``q``.
    """
    q_head_num = q.shape[1]
    head_dim = q.shape[2]

    # Sequence-block size shared by both kernel stages; stage 2 must reduce
    # exactly the per-block partials that stage 1 emitted.
    BLOCK_SEQ = 256
    batch_size = infer_state.batch_size
    calcu_shape1 = (batch_size, q_head_num, head_dim)

    o_tensor = alloc_tensor_func(q.shape, q.dtype, q.device) if out is None else out

    # Balance parallelism against scratch-memory footprint: intermediate
    # tensors scale with block_num, so small batches get more blocks (better
    # GPU occupancy) while large batches get fewer (bounded memory use).
    if batch_size <= 16:
        block_num = 128
    elif batch_size <= 64:
        block_num = 64
    else:
        block_num = 32

    # Allocate scratch on q's device (was hardcoded "cuda", inconsistent with
    # the o_tensor allocation above and wrong for non-default CUDA devices).
    mid_o = alloc_tensor_func([batch_size, q_head_num, block_num, head_dim], dtype=q.dtype, device=q.device)
    # float32 regardless of q.dtype: keeps the log-sum-exp accumulation
    # numerically stable for fp16/bf16 queries.
    mid_o_logexpsum = alloc_tensor_func([batch_size, q_head_num, block_num], dtype=torch.float32, device=q.device)

    flash_decode_stage1(
        q=q.view(calcu_shape1),
        k=cache_k,
        k_scale=cache_k_scale,
        v=cache_v,
        v_scale=cache_v_scale,
        Req_to_tokens=infer_state.req_manager.req_to_token_indexs,
        B_req_idx=infer_state.b_req_idx,
        B_seq_len=infer_state.b_seq_len,
        max_len_in_batch=infer_state.max_kv_seq_len,
        mid_out=mid_o,
        mid_out_logsumexp=mid_o_logexpsum,
        block_seq=BLOCK_SEQ,
    )

    flash_decode_stage2(
        mid_out=mid_o,
        mid_out_logexpsum=mid_o_logexpsum,
        B_Seqlen=infer_state.b_seq_len,
        O=o_tensor.view(calcu_shape1),
        block_seq=BLOCK_SEQ,
    )
    return o_tensor
Loading
Loading