Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 26 additions & 66 deletions custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -430,8 +430,7 @@ template <typename T,
uint32_t num_frags_x,
uint32_t num_frags_z,
uint32_t num_frags_y,
typename OutT = T,
bool ENABLE_PREFILL = true>
typename OutT = T>
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 只有 multi_query_append_attention_warp1_4_kernel 删除了 ENABLE_PREFILL 模板参数

注意到 multiquery_attention_c4_impl.cuh 和 multiquery_attention_c8_impl.cuh 中仍然保留了 ENABLE_PREFILL 参数。这是否是有意为之(分阶段重构),还是遗漏?

如果是有意为之,建议在 PR 描述中说明原因。

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

分阶段重构

__global__ void multi_query_append_attention_warp1_4_kernel(
T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim]
T *__restrict__ cache_k, // [max_block_num, num_heads, block_size,
Expand Down Expand Up @@ -525,17 +524,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
if (!partition_kv || num_chunks_this_seq <= 1) {
o_base_ptr_int8 = out + o_offset;
} else {
if (ENABLE_PREFILL) {
o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
} else {
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
o_base_ptr_T =
tmp_workspace +
batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride +
chunk_idx * q_n_stride + q_head_idx * HEAD_DIM +
tid % 8 * num_elems_per_128b<T>();
}
const int *mask_offset_this_seq =
mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr;
Expand Down Expand Up @@ -799,18 +792,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel(
const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE;

if (qo_idx - q_start_seq_id < q_len) {
uint32_t offset;
if (ENABLE_PREFILL) {
offset = (batch_id * num_chunks + chunk_idx) * q_num_heads +
qo_head_idx;
} else {
offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
}
const uint32_t offset = ((batch_id * speculate_max_draft_token_num +
qo_idx_now / GROUP_SIZE) *
num_chunks +
chunk_idx) *
q_num_heads +
qo_head_idx;
tmp_m[offset] = m_frag[fx][j];
tmp_d[offset] = d_frag[fx][j];
}
Expand Down Expand Up @@ -1123,8 +1110,7 @@ void MultiQueryAppendAttention(
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL>;
OUT_NV_TYPE>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(split_kv_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
Expand Down Expand Up @@ -1169,8 +1155,7 @@ void MultiQueryAppendAttention(
num_frags_x,
num_frags_z,
num_frags_y,
OUT_NV_TYPE,
ENABLE_PREFILL>;
OUT_NV_TYPE>;
if (smem_size >= 48 * 1024) {
cudaFuncSetAttribute(nosplit_kv_kernel,
cudaFuncAttributeMaxDynamicSharedMemorySize,
Expand Down Expand Up @@ -1222,43 +1207,18 @@ void MultiQueryAppendAttention(
sink_size);
} else {
phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d;
if (is_decoder) {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(bsz * num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(bsz * num_chunks * num_heads));
} else {
if (ENABLE_PREFILL) {
tmp_workspace =
allocator->Allocate(phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(token_num * num_chunks *
num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(token_num * num_chunks * num_heads));
} else {
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz *
num_chunks * num_heads));
}
}
tmp_workspace = allocator->Allocate(
phi::SizeOf(qkv.dtype()) *
static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
num_heads * HEAD_DIM));
tmp_m = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
num_heads));
tmp_d = allocator->Allocate(
phi::SizeOf(paddle::DataType::FLOAT32) *
static_cast<size_t>(speculate_max_draft_token_num * bsz * num_chunks *
num_heads));
launchWithPdlWhenEnabled(
split_kv_kernel,
grids,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def __init__(
self.causal: bool = getattr(fd_config.model_config, "causal", True)
self.speculative_method = fd_config.speculative_config.method
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
if self.speculative_method is None:
self.speculate_max_draft_token_num = 0
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 建议添加注释说明此设置的目的

这个修改确保非 speculative 模式下 speculate_max_draft_token_num=0,与 CUDA 端删除 ENABLE_PREFILL 后的统一内存布局配合。建议添加注释说明这个关联,方便后续维护。

# When not using speculative decoding, set to 0. The CUDA kernel will
# receive (speculate_max_draft_token_num + 1) = 1, which matches the
# simplified memory layout after removing ENABLE_PREFILL branches.
if self.speculative_method is None:
    self.speculate_max_draft_token_num = 0

self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ def __init__(
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
if not self.use_speculate:
self.speculate_max_draft_token_num = 0
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def __init__(
self.speculative_method = fd_config.speculative_config.method
self.use_speculate = self.speculative_method is not None
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
if not self.use_speculate:
self.speculate_max_draft_token_num = 0
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)

Expand Down
Loading