diff --git a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh index 3bd148bb601..e7463154c43 100644 --- a/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh +++ b/custom_ops/gpu_ops/append_attn/multiquery_attention_c16_impl.cuh @@ -430,8 +430,7 @@ template + typename OutT = T> __global__ void multi_query_append_attention_warp1_4_kernel( T *__restrict__ q, // [token_num, (num_heads + 2* kv_num_head) * head_dim] T *__restrict__ cache_k, // [max_block_num, num_heads, block_size, @@ -525,17 +524,11 @@ __global__ void multi_query_append_attention_warp1_4_kernel( if (!partition_kv || num_chunks_this_seq <= 1) { o_base_ptr_int8 = out + o_offset; } else { - if (ENABLE_PREFILL) { - o_base_ptr_T = tmp_workspace + batch_id * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } else { - o_base_ptr_T = - tmp_workspace + - batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + - chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + - tid % 8 * num_elems_per_128b(); - } + o_base_ptr_T = + tmp_workspace + + batch_id * speculate_max_draft_token_num * num_chunks * q_n_stride + + chunk_idx * q_n_stride + q_head_idx * HEAD_DIM + + tid % 8 * num_elems_per_128b(); } const int *mask_offset_this_seq = mask_offset ? mask_offset + q_start_seq_id * 2 : nullptr; @@ -799,18 +792,12 @@ __global__ void multi_query_append_attention_warp1_4_kernel( const uint32_t qo_idx = q_start_seq_id + qo_idx_now / GROUP_SIZE; if (qo_idx - q_start_seq_id < q_len) { - uint32_t offset; - if (ENABLE_PREFILL) { - offset = (batch_id * num_chunks + chunk_idx) * q_num_heads + - qo_head_idx; - } else { - offset = ((batch_id * speculate_max_draft_token_num + - qo_idx_now / GROUP_SIZE) * - num_chunks + - chunk_idx) * - q_num_heads + - qo_head_idx; - } + const uint32_t offset = ((batch_id * speculate_max_draft_token_num + + qo_idx_now / GROUP_SIZE) * + num_chunks + + chunk_idx) * + q_num_heads + + qo_head_idx; tmp_m[offset] = m_frag[fx][j]; tmp_d[offset] = d_frag[fx][j]; } @@ -1123,8 +1110,7 @@ void MultiQueryAppendAttention( num_frags_x, num_frags_z, num_frags_y, - OUT_NV_TYPE, - ENABLE_PREFILL>; + OUT_NV_TYPE>; if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(split_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -1169,8 +1155,7 @@ void MultiQueryAppendAttention( num_frags_x, num_frags_z, num_frags_y, - OUT_NV_TYPE, - ENABLE_PREFILL>; + OUT_NV_TYPE>; if (smem_size >= 48 * 1024) { cudaFuncSetAttribute(nosplit_kv_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, @@ -1222,43 +1207,18 @@ void MultiQueryAppendAttention( sink_size); } else { phi::Allocator::AllocationPtr tmp_workspace, tmp_m, tmp_d; - if (is_decoder) { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(bsz * num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(bsz * num_chunks * num_heads)); - } else { - if (ENABLE_PREFILL) { - tmp_workspace = - allocator->Allocate(phi::SizeOf(qkv.dtype()) * - static_cast(token_num * num_chunks * - num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(token_num * num_chunks * num_heads)); - } else { - tmp_workspace = allocator->Allocate( - phi::SizeOf(qkv.dtype()) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads * HEAD_DIM)); - tmp_m = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - tmp_d = allocator->Allocate( - phi::SizeOf(paddle::DataType::FLOAT32) * - static_cast(speculate_max_draft_token_num * bsz * - num_chunks * num_heads)); - } - } + tmp_workspace = allocator->Allocate( + phi::SizeOf(qkv.dtype()) * + static_cast(speculate_max_draft_token_num * bsz * num_chunks * + num_heads * HEAD_DIM)); + tmp_m = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * num_chunks * + num_heads)); + tmp_d = allocator->Allocate( + phi::SizeOf(paddle::DataType::FLOAT32) * + static_cast(speculate_max_draft_token_num * bsz * num_chunks * + num_heads)); launchWithPdlWhenEnabled( split_kv_kernel, grids, diff --git a/fastdeploy/model_executor/layers/attention/append_attn_backend.py b/fastdeploy/model_executor/layers/attention/append_attn_backend.py index 81eab7cce86..c73283b48de 100644 --- a/fastdeploy/model_executor/layers/attention/append_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/append_attn_backend.py @@ -146,6 +146,8 @@ def __init__( self.causal: bool = getattr(fd_config.model_config, "causal", True) self.speculative_method = fd_config.speculative_config.method self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens + if self.speculative_method is None: + self.speculate_max_draft_token_num = 0 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) diff --git a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py index b51dce1449d..3ac0bfafbb6 100644 --- a/fastdeploy/model_executor/layers/attention/flash_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_attn_backend.py @@ -258,6 +258,8 @@ def __init__( self.speculative_method = fd_config.speculative_config.method self.use_speculate = self.speculative_method is not None self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens + if not self.use_speculate: + self.speculate_max_draft_token_num = 0 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP) diff --git a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py index 35d27504ab5..6e05ca0c3b8 100644 --- a/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py +++ b/fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py @@ -109,6 +109,8 @@ def __init__( self.speculative_method = fd_config.speculative_config.method self.use_speculate = self.speculative_method is not None self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens + if not self.use_speculate: + self.speculate_max_draft_token_num = 0 self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp" self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)