diff --git a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu index 529bfd9ab0e..778a6112367 100644 --- a/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu +++ b/custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu @@ -12,19 +12,215 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include -#include -#include #include -#include #include +#include +#include +#include #include "paddle/extension.h" +#include "../ngram_match_common.cuh" #ifndef PD_BUILD_STATIC_OP #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) #endif -int sum_mixed(const int *value, int num) { +// ============================================================ +// Phase 1 mixed search kernel — one block per batch item. +// Also copies tentative matched tokens to scratch buffers. +// ============================================================ +__global__ void ngram_match_mixed_search_kernel( + const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *draft_token_num, + const int32_t *seq_lens_this_time, + const int64_t *max_dec_len, + int64_t *draft_tokens_copy, + int32_t *seq_lens_this_time_copy, + int64_t input_ids_stride, + int64_t pre_ids_stride, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int max_ngram_size, + int min_ngram_size, + int max_draft_tokens_param) { + int batch_idx = blockIdx.x; + if (batch_idx >= max_batch_size) return; + + __shared__ int64_t s_min_pos; + + const int ori_seq_len_this_time = seq_lens_this_time[batch_idx]; + + if (threadIdx.x == 0) { + // Default: keep the original seq_lens_this_time (no ngram match) + seq_lens_this_time_copy[batch_idx] = ori_seq_len_this_time; + } + __syncthreads(); + + // Skip batch items with no active tokens + if (ori_seq_len_this_time == 0) return; + + // Compute 
max_draft_tokens for this batch item. + // Split into explicit steps to avoid negative intermediate values. + int64_t draft_budget = + static_cast(max_draft_tokens_param) - ori_seq_len_this_time + 1; + int64_t remaining_dec = max_dec_len[batch_idx] - step_idx[batch_idx] - 1; + if (draft_budget <= 0 || remaining_dec <= 0) return; + int max_draft_tokens = static_cast(min(draft_budget, remaining_dec)); + + const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; + const int64_t cur_input_ids_len = input_ids_len[batch_idx]; + const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride; + const int64_t cur_step_idx = step_idx[batch_idx]; + + for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size; + --ngram_size) { + if (cur_step_idx < ngram_size) continue; + + const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + + int64_t pos = parallel_ngram_search( + cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos); + if (pos != INT64_MAX) { + int64_t start_idx = pos + ngram_size; + int64_t end_idx = min(start_idx + static_cast(max_draft_tokens), + cur_input_ids_len); + if (threadIdx.x == 0 && start_idx < end_idx) { + // Tentative token copy to scratch + int64_t n = end_idx - start_idx; + seq_lens_this_time_copy[batch_idx] = + static_cast(ori_seq_len_this_time + n); + int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride; + for (int64_t k = 0; k < n; k++) { + dst[ori_seq_len_this_time + k] = cur_input_ids[start_idx + k]; + } + } + // Only early-exit when tokens were actually produced + if (start_idx < end_idx) { + return; + } + } + + pos = parallel_ngram_search( + cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos); + if (pos != INT64_MAX) { + int64_t start_idx = pos + ngram_size; + int64_t end_idx = + min(start_idx + static_cast(max_draft_tokens), cur_step_idx); + if (threadIdx.x == 0 && start_idx < end_idx) { + // Tentative token copy to scratch + int64_t n = end_idx - start_idx; + 
seq_lens_this_time_copy[batch_idx] = + static_cast(ori_seq_len_this_time + n); + int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride; + for (int64_t k = 0; k < n; k++) { + dst[ori_seq_len_this_time + k] = cur_pre_ids[start_idx + k]; + } + } + // Only early-exit when tokens were actually produced + if (start_idx < end_idx) { + return; + } + } + } +} + +// ============================================================ +// Phase 2 mixed gather kernel — BlockScan threshold + copy +// <<<1, NGRAM_GATHER_THREADS>>> +// +// Reads tentative allocations from Phase 1 scratch buffers, +// computes prefix sums to enforce the global threshold, then +// writes final seq_lens_this_time and copies draft tokens. +// The mixed variant respects ori_seq_len_this_time (MTP tokens). +// ============================================================ +__global__ void ngram_match_mixed_gather_kernel( + const int64_t *draft_tokens_copy, + const int32_t *seq_lens_this_time_copy, + const int32_t *seq_lens_this_time_orig, + int64_t *draft_tokens, + int32_t *seq_lens_this_time, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int threshold) { + typedef cub::BlockScan BlockScanInt; + __shared__ typename BlockScanInt::TempStorage temp_storage1; + __shared__ typename BlockScanInt::TempStorage temp_storage2; + __shared__ int s_total_active; + + int tid = threadIdx.x; + + // Load tentative total token count from Phase 1 + int tentative = 0; + int is_active = 0; + if (tid < max_batch_size) { + tentative = seq_lens_this_time_copy[tid]; + is_active = (tentative > 0) ? 
1 : 0; + } + + // Scan 1: inclusive prefix sum of tentative token counts + int token_prefix; + BlockScanInt(temp_storage1).InclusiveSum(tentative, token_prefix); + __syncthreads(); + + // Scan 2: inclusive prefix sum of active-item indicators + int active_prefix; + BlockScanInt(temp_storage2).InclusiveSum(is_active, active_prefix); + __syncthreads(); + + // Total active count from the last valid thread + if (tid == + min(static_cast(max_batch_size) - 1, NGRAM_GATHER_THREADS - 1)) { + s_total_active = active_prefix; + } + __syncthreads(); + + if (tid < max_batch_size) { + if (tentative == 0) { + seq_lens_this_time[tid] = 0; + return; + } + + int ori = seq_lens_this_time_orig[tid]; + int ngram_tokens = tentative - ori; // tokens added by ngram match + + int exclusive_token_prefix = token_prefix - tentative; + int remaining_active = s_total_active - active_prefix; + + // Budget: threshold minus tokens already allocated before me, + // minus at-least-1 reservation for every active item after me. + int budget = threshold - exclusive_token_prefix - remaining_active; + + int actual; + if (budget <= ori) { + // Can't even keep all MTP base tokens — keep original only + actual = ori; + } else { + int ngram_budget = budget - ori; + actual = ori + min(ngram_tokens, ngram_budget); + } + actual = min(actual, tentative); + + seq_lens_this_time[tid] = actual; + + // Copy ngram draft tokens from scratch to output + int ngram_to_copy = actual - ori; + if (ngram_to_copy > 0) { + int64_t *dst = draft_tokens + tid * draft_tokens_stride; + const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride; + for (int k = 0; k < ngram_to_copy; k++) { + dst[ori + k] = src[ori + k]; + } + } + } +} + +// ============================================================ +// CPU path — preserved from original for backward compatibility +// with CPU-only callers and tests. 
+// ============================================================ +static int sum_mixed_cpu(const int *value, int num) { int sum_value = 0; for (int i = 0; i <= num; i++) { sum_value += value[i]; @@ -32,24 +228,23 @@ int sum_mixed(const int *value, int num) { return sum_value; } -void find_candidate_pred_tokens_mixed(const int64_t *input_ids, - const int64_t *input_ids_len, - const int64_t *pre_ids, - const int64_t *step_idx, - const int *draft_token_num, - int64_t *draft_tokens, - int32_t *seq_lens_this_time, - int32_t *seq_lens_decoder, - int64_t *max_dec_len, - int64_t input_ids_stride, - int64_t pre_ids_stride, - int64_t draft_tokens_stride, - int64_t max_batch_size, - int max_ngram_size = 3, - int min_ngram_size = 1, - const int max_draft_tokens = 10) { +static void find_candidate_pred_tokens_mixed(const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *pre_ids, + const int64_t *step_idx, + const int *draft_token_num, + int64_t *draft_tokens, + int32_t *seq_lens_this_time, + int32_t *seq_lens_decoder, + int64_t *max_dec_len, + int64_t input_ids_stride, + int64_t pre_ids_stride, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int max_ngram_size = 3, + int min_ngram_size = 1, + const int max_draft_tokens = 10) { int threshold = 1024; - // dynamic in future char *env_var = getenv("SPEC_TOKENUM_THRESHOLD"); if (env_var) { threshold = std::stoi(env_var); @@ -62,13 +257,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids, } for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { const int ori_seq_len_this_time = seq_lens_this_time[batch_idx]; - int max_draft_tokens_query = std::min( - static_cast(max_draft_tokens - ori_seq_len_this_time + 1), - max_dec_len[batch_idx] - step_idx[batch_idx] - 1); + // Split into explicit int64_t steps to avoid negative intermediate values. 
+ int64_t draft_budget = + static_cast(max_draft_tokens) - ori_seq_len_this_time + 1; + int64_t remaining_dec = max_dec_len[batch_idx] - step_idx[batch_idx] - 1; - if (ori_seq_len_this_time == 0 || max_draft_tokens_query <= 0) { + if (ori_seq_len_this_time == 0 || draft_budget <= 0 || remaining_dec <= 0) { continue; } + int max_draft_tokens_query = + static_cast(std::min(draft_budget, remaining_dec)); const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; @@ -77,7 +275,7 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids, const int64_t cur_input_ids_len = input_ids_len[batch_idx]; unprocessed_batch_size--; - auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx); + auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx); int left_min_token_num = unprocessed_batch_size; if (sum_token_num + max_draft_tokens_query + left_min_token_num > @@ -91,21 +289,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids, continue; } bool match_global = false; - // apply ngram_match in input_ids for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size && !match_global; --ngram_size) { - // Extract the last n tokens as our search ngram if (cur_step_idx < ngram_size) { continue; } const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); - // Iterate through sliding windows of size ngram_size - // bool match_input = false; for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global; ++i) { - // Check if the current window matches the ngram bool match_local = true; for (int j = 0; j < ngram_size; j++) { if (ngram[j] != cur_input_ids[i + j]) { @@ -120,24 +313,19 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids, if (start_idx >= end_idx) continue; int64_t cur_draft_token_num = end_idx - start_idx; - seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num; memcpy(cur_draft_tokens + 
ori_seq_len_this_time, cur_input_ids + start_idx, sizeof(int64_t) * cur_draft_token_num); - // To break the current batch_idx for-loop match_global = true; break; } } - // apply ngram_match in generated tokens if (!match_global) { for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global; ++i) { - // Check if the current window matches the ngram bool match_local = true; - for (int j = 0; j < ngram_size; j++) { if (ngram[j] != cur_pre_ids[i + j]) { match_local = false; @@ -148,13 +336,8 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids, int64_t start_idx = i + ngram_size; int64_t end_idx = std::min(start_idx + max_draft_tokens_query, cur_step_idx); - int64_t cur_draft_token_num = end_idx - start_idx; - if (start_idx >= end_idx) continue; - // printf("match in Output with Ngram_size %d. - // %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx, - // end_idx); seq_lens_this_time[batch_idx] = ori_seq_len_this_time + cur_draft_token_num; @@ -170,6 +353,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids, } } +// ============================================================ +// GPU path — Two-phase parallel CUDA kernels for hybrid ngram matching. +// +// Phase 1: <<>> — parallel sliding-window +// search within each batch item (NGRAM_BLOCK_THREADS threads +// per block). Also copies matched draft tokens to scratch. +// Phase 2: <<<1, NGRAM_GATHER_THREADS>>> — CUB BlockScan prefix-sum +// threshold enforcement + final token copy. 
+// ============================================================ + void HybridMtpNgram(const paddle::Tensor &input_ids, const paddle::Tensor &input_ids_len, const paddle::Tensor &pre_ids, @@ -193,23 +386,101 @@ void HybridMtpNgram(const paddle::Tensor &input_ids, const int64_t max_batch_size = seq_lens_this_time.shape()[0]; - find_candidate_pred_tokens_mixed( - input_ids.data(), - input_ids_len.data(), - pre_ids.data(), - step_idx.data(), - draft_token_num.data(), - const_cast(draft_tokens.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_decoder.data()), - const_cast(max_dec_len.data()), - input_ids_stride, - pre_ids_stride, - draft_tokens_stride, - max_batch_size, - max_ngram_size, - min_ngram_size, - max_draft_tokens); + int threshold = 1024; + const char *env_var = getenv("SPEC_TOKENUM_THRESHOLD"); + if (env_var) { + threshold = std::stoi(env_var); + } + + if (input_ids.is_gpu()) { + auto stream = input_ids.stream(); + + // NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed + // variant uses ori_seq_len_this_time == 0 to skip inactive items. This + // matches CPU behavior under the invariant that seq_lens_decoder > 0 iff + // ori_seq_len_this_time > 0 (holds during normal MTP decoding). The CPU + // path counts seq_lens_decoder > 0 for threshold budget; the GPU scan + // counts tentative > 0, which is equivalent under this invariant. 
+ + // Allocate scratch buffers for Phase 1 → Phase 2 communication + + // Scratch copy of draft_tokens (Phase 1 writes tentative tokens here) + auto draft_tokens_copy = + paddle::empty({max_batch_size, draft_tokens_stride}, + paddle::DataType::INT64, + input_ids.place()); + + // Scratch copy of seq_lens_this_time (Phase 1 writes tentative counts) + auto seq_lens_this_time_copy = paddle::empty( + {max_batch_size}, paddle::DataType::INT32, input_ids.place()); + + // Save a copy of original seq_lens_this_time for Phase 2 + // (Phase 1 reads from the original, Phase 2 needs ori values) + auto seq_lens_this_time_orig = paddle::empty( + {max_batch_size}, paddle::DataType::INT32, input_ids.place()); + cudaMemcpyAsync(seq_lens_this_time_orig.data(), + seq_lens_this_time.data(), + max_batch_size * sizeof(int32_t), + cudaMemcpyDeviceToDevice, + stream); + + // Fail-fast: BlockScan Phase 2 requires max_batch_size ≤ block size. + PD_CHECK(max_batch_size <= NGRAM_GATHER_THREADS, + "hybrid_mtp_ngram: max_batch_size exceeds NGRAM_GATHER_THREADS"); + + // Phase 1: parallel search — one block per batch item. + // Also copies matched tokens to scratch and writes tentative seq_lens. + ngram_match_mixed_search_kernel<<>>( + input_ids.data(), + input_ids_len.data(), + pre_ids.data(), + step_idx.data(), + draft_token_num.data(), + seq_lens_this_time.data(), + max_dec_len.data(), + draft_tokens_copy.data(), + seq_lens_this_time_copy.data(), + input_ids_stride, + pre_ids_stride, + draft_tokens_stride, + max_batch_size, + max_ngram_size, + min_ngram_size, + max_draft_tokens); + + // Phase 2: BlockScan threshold enforcement + final token copy. + // <<<1, NGRAM_GATHER_THREADS>>> — all batch items handled by one block. 
+ ngram_match_mixed_gather_kernel<<<1, NGRAM_GATHER_THREADS, 0, stream>>>( + draft_tokens_copy.data(), + seq_lens_this_time_copy.data(), + seq_lens_this_time_orig.data(), + const_cast(draft_tokens.data()), + const_cast(seq_lens_this_time.data()), + draft_tokens_stride, + max_batch_size, + threshold); + } else { + find_candidate_pred_tokens_mixed( + input_ids.data(), + input_ids_len.data(), + pre_ids.data(), + step_idx.data(), + draft_token_num.data(), + const_cast(draft_tokens.data()), + const_cast(seq_lens_this_time.data()), + const_cast(seq_lens_decoder.data()), + const_cast(max_dec_len.data()), + input_ids_stride, + pre_ids_stride, + draft_tokens_stride, + max_batch_size, + max_ngram_size, + min_ngram_size, + max_draft_tokens); + } } PD_BUILD_STATIC_OP(hybrid_mtp_ngram) diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cc b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cc deleted file mode 100644 index 56a2d3f81c3..00000000000 --- a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cc +++ /dev/null @@ -1,227 +0,0 @@ -// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -#include -#include -#include -#include -#include -#include -#include "paddle/extension.h" - -#ifndef PD_BUILD_STATIC_OP -#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) -#endif - -int sum(const int *value, int num) { - int sum_value = 0; - for (int i = 0; i <= num; i++) { - sum_value += value[i]; - } - return sum_value; -} - -void find_candidate_pred_tokens(const int64_t *input_ids, - const int64_t *input_ids_len, - const int64_t *token_ids_all, - const int64_t *prompt_lens, - const int64_t *step_idx, - const int *draft_token_num, - int64_t *draft_tokens, - int32_t *seq_lens_this_time, - int32_t *seq_lens_encoder, - int32_t *seq_lens_decoder, - int64_t *max_dec_len, - int64_t input_ids_stride, - int64_t max_model_len, - int64_t draft_tokens_stride, - int64_t max_batch_size, - int max_ngram_size = 3, - int max_draft_tokens = 10) { - int threshold = 128; - char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"); - if (env_var) { - threshold = std::stoi(env_var); - } - int unprocessed_batch_size = 0; - for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { - if (seq_lens_encoder[batch_idx] > 0 || seq_lens_decoder[batch_idx] > 0) { - unprocessed_batch_size++; - } - } - for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { - max_draft_tokens = - std::min(static_cast(draft_token_num[batch_idx]), - max_dec_len[batch_idx] - step_idx[batch_idx] - 1); - if (seq_lens_encoder[batch_idx] > 0) { - continue; - } else if (seq_lens_decoder[batch_idx] == 0) { - seq_lens_this_time[batch_idx] = 0; - continue; - } - // printf("bid: %d. enc: %d. dec. 
%d\n", batch_idx, - // seq_lens_encoder[batch_idx], seq_lens_decoder[batch_idx]); - - const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; - int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; - const int64_t *cur_pre_ids = - token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx]; - const int64_t cur_step_idx = step_idx[batch_idx]; - const int64_t cur_input_ids_len = input_ids_len[batch_idx]; - seq_lens_this_time[batch_idx] = 1; - unprocessed_batch_size--; - - auto sum_token_num = sum(seq_lens_this_time, batch_idx); - int left_min_token_num = unprocessed_batch_size; - - if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) { - int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num; - max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens - ? tmp_max_draft_tokens - : max_draft_tokens; - } - - if (sum_token_num + left_min_token_num >= threshold - 1) { - continue; - } - - for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) { - // Extract the last n tokens as our search ngram - if (cur_step_idx < ngram_size) { - continue; - } - const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); - - // Iterate through sliding windows of size ngram_size - bool match_input = false; - for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) { - // Check if the current window matches the ngram - bool match = true; - for (int j = 0; j < ngram_size; j++) { - if (ngram[j] != cur_input_ids[i + j]) { - match = false; - break; - } - } - if (match) { - int64_t start_idx = i + ngram_size; - int64_t end_idx = - std::min(start_idx + max_draft_tokens, cur_input_ids_len); - if (start_idx >= end_idx) continue; - - int64_t cur_draft_token_num = end_idx - start_idx; - - seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; - memcpy(cur_draft_tokens + 1, - cur_input_ids + start_idx, - sizeof(int64_t) * cur_draft_token_num); - // To break the current batch_idx for-loop - 
ngram_size = 0; - match_input = true; - break; - // } - } - } - if (!match_input) { - for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) { - // Check if the current window matches the ngram - bool match = true; - - for (int j = 0; j < ngram_size; j++) { - if (ngram[j] != cur_pre_ids[i + j]) { - match = false; - break; - } - } - - if (match) { - int64_t start_idx = i + ngram_size; - int64_t end_idx = - std::min(start_idx + max_draft_tokens, cur_step_idx); - int64_t cur_draft_token_num = end_idx - start_idx; - if (start_idx >= end_idx) continue; - - seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; - memcpy(cur_draft_tokens + 1, - cur_pre_ids + start_idx, - sizeof(int64_t) * cur_draft_token_num); - ngram_size = 0; - break; - } - } - } - } - } -} - -void NgramMatch(const paddle::Tensor &input_ids, - const paddle::Tensor &input_ids_len, - const paddle::Tensor &token_ids_all, - const paddle::Tensor &prompt_lens, - const paddle::Tensor &step_idx, - const paddle::Tensor &draft_token_num, - const paddle::Tensor &draft_tokens, - const paddle::Tensor &seq_lens_this_time, - const paddle::Tensor &seq_lens_encoder, - const paddle::Tensor &seq_lens_decoder, - const paddle::Tensor &max_dec_len, - const int max_ngram_size, - const int max_draft_tokens) { - auto input_ids_shape = input_ids.shape(); - const int64_t input_ids_stride = input_ids_shape[1]; - - const int64_t max_model_len = token_ids_all.shape()[1]; - - auto draft_tokens_shape = draft_tokens.shape(); - const int64_t draft_tokens_stride = draft_tokens_shape[1]; - - const int64_t max_batch_size = seq_lens_this_time.shape()[0]; - - find_candidate_pred_tokens( - input_ids.data(), - input_ids_len.data(), - token_ids_all.data(), - prompt_lens.data(), - step_idx.data(), - draft_token_num.data(), - const_cast(draft_tokens.data()), - const_cast(seq_lens_this_time.data()), - const_cast(seq_lens_encoder.data()), - const_cast(seq_lens_decoder.data()), - const_cast(max_dec_len.data()), - input_ids_stride, - 
max_model_len, - draft_tokens_stride, - max_batch_size, - max_ngram_size, - max_draft_tokens); -} - -PD_BUILD_STATIC_OP(ngram_match) - .Inputs({"input_ids", - "input_ids_len", - "token_ids_all", - "prompt_lens", - "step_idx", - "draft_token_num", - "draft_tokens", - "seq_lens_this_time", - "seq_lens_encoder", - "seq_lens_decoder", - "max_dec_len"}) - .Attrs({"max_ngram_size: int", "max_draft_tokens: int"}) - .Outputs({"draft_tokens_out", "seq_lens_this_time_out"}) - .SetKernelFn(PD_KERNEL(NgramMatch)) - .SetInplaceMap({{"draft_tokens", "draft_tokens_out"}, - {"seq_lens_this_time", "seq_lens_this_time_out"}}); diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu new file mode 100644 index 00000000000..2f4904ee26c --- /dev/null +++ b/custom_ops/gpu_ops/speculate_decoding/ngram_match.cu @@ -0,0 +1,504 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include +#include +#include "paddle/extension.h" +#include "ngram_match_common.cuh" + +// ============================================================ +// Phase 1 search kernel — one block per batch item. +// Finds the leftmost ngram match and writes tentative draft +// tokens to a scratch buffer (draft_tokens_copy) along with +// the tentative new seq_lens_this_time to a copy buffer. +// Phase 2 will decide which ones to keep (threshold logic). 
+// ============================================================ +__global__ void ngram_match_search_kernel(const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *token_ids_all, + const int64_t *prompt_lens, + const int64_t *step_idx, + const int *draft_token_num, + const int32_t *seq_lens_encoder, + const int32_t *seq_lens_decoder, + const int64_t *max_dec_len, + int64_t *draft_tokens_copy, + int32_t *seq_lens_this_time_copy, + int64_t input_ids_stride, + int64_t max_model_len, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int max_ngram_size) { + int batch_idx = blockIdx.x; + if (batch_idx >= max_batch_size) return; + + __shared__ int64_t s_min_pos; + + if (threadIdx.x == 0) { + // Default: 0 = this item contributes nothing to threshold budget. + // Active decoder items will be set to 1+ below. + seq_lens_this_time_copy[batch_idx] = 0; + } + __syncthreads(); + + // Skip if encoder active (preserves original seq_lens_this_time) or + // decoder inactive (Phase 2 writes 0 for these). + if (seq_lens_encoder[batch_idx] > 0) return; + if (seq_lens_decoder[batch_idx] == 0) return; + + // Active decoder item: at least the base token. 
+ if (threadIdx.x == 0) seq_lens_this_time_copy[batch_idx] = 1; + + const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; + const int64_t cur_input_ids_len = input_ids_len[batch_idx]; + const int64_t prompt_len = prompt_lens[batch_idx]; + const int64_t *cur_pre_ids = + token_ids_all + batch_idx * max_model_len + prompt_len; + const int64_t cur_step_idx = step_idx[batch_idx]; + + // Compute max_draft_tokens for this batch item + int64_t remaining = max_dec_len[batch_idx] - cur_step_idx - 1; + if (remaining <= 0) return; + int max_draft_tokens = static_cast( + min(static_cast(draft_token_num[batch_idx]), remaining)); + + for (int ngram_size = max_ngram_size; ngram_size >= 1; --ngram_size) { + if (cur_step_idx < ngram_size) continue; + + const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + + int64_t pos = parallel_ngram_search( + cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos); + if (pos != INT64_MAX) { + int64_t start_idx = pos + ngram_size; + int64_t end_idx = min(start_idx + static_cast(max_draft_tokens), + cur_input_ids_len); + if (threadIdx.x == 0 && start_idx < end_idx) { + // Tentative token copy to scratch + int64_t n = end_idx - start_idx; + seq_lens_this_time_copy[batch_idx] = static_cast(1 + n); + int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride; + for (int64_t k = 0; k < n; k++) { + dst[1 + k] = cur_input_ids[start_idx + k]; + } + } + // Only early-exit when tokens were actually produced + if (start_idx < end_idx) { + return; + } + } + + pos = parallel_ngram_search( + cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos); + if (pos != INT64_MAX) { + int64_t start_idx = pos + ngram_size; + int64_t end_idx = + min(start_idx + static_cast(max_draft_tokens), cur_step_idx); + if (threadIdx.x == 0 && start_idx < end_idx) { + // Tentative token copy to scratch + int64_t n = end_idx - start_idx; + seq_lens_this_time_copy[batch_idx] = static_cast(1 + n); + int64_t *dst = draft_tokens_copy 
+ batch_idx * draft_tokens_stride; + for (int64_t k = 0; k < n; k++) { + dst[1 + k] = cur_pre_ids[start_idx + k]; + } + } + // Only early-exit when tokens were actually produced + if (start_idx < end_idx) { + return; + } + } + } +} + +// ============================================================ +// Phase 2 gather kernel — BlockScan threshold + copy +// <<<1, NGRAM_GATHER_THREADS>>> +// +// Reads tentative allocations from Phase 1 scratch buffers, +// computes prefix sums to enforce the global threshold, then +// writes final seq_lens_this_time and copies draft tokens. +// ============================================================ +__global__ void ngram_match_gather_kernel( + const int64_t *draft_tokens_copy, + const int32_t *seq_lens_this_time_copy, + const int32_t *seq_lens_encoder, + int64_t *draft_tokens, + int32_t *seq_lens_this_time, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int threshold) { + typedef cub::BlockScan BlockScanInt; + __shared__ typename BlockScanInt::TempStorage temp_storage1; + __shared__ typename BlockScanInt::TempStorage temp_storage2; + __shared__ int s_total_active; + + int tid = threadIdx.x; + + // Load tentative values from Phase 1. + // Encoder-active items are included in the scan with their original + // seq_lens_this_time to match CPU threshold-budget accounting. + int tentative = 0; + int is_active = 0; + if (tid < max_batch_size) { + if (seq_lens_encoder[tid] > 0) { + // Encoder-active: contribute original token count to threshold budget. + // seq_lens_this_time[tid] is still unmodified at this point. + tentative = seq_lens_this_time[tid]; + is_active = 1; + } else { + tentative = seq_lens_this_time_copy[tid]; + is_active = (tentative > 0) ? 
1 : 0; + } + } + + // Scan 1: inclusive prefix sum of tentative token counts + int token_prefix; + BlockScanInt(temp_storage1).InclusiveSum(tentative, token_prefix); + __syncthreads(); + + // Scan 2: inclusive prefix sum of active-item indicators + int active_prefix; + BlockScanInt(temp_storage2).InclusiveSum(is_active, active_prefix); + __syncthreads(); + + // Total active count from the last valid thread + if (tid == + min(static_cast(max_batch_size) - 1, NGRAM_GATHER_THREADS - 1)) { + s_total_active = active_prefix; + } + __syncthreads(); + + if (tid < max_batch_size) { + // Encoder-active items: preserve original seq_lens_this_time. + if (seq_lens_encoder[tid] > 0) return; + + if (tentative == 0) { + seq_lens_this_time[tid] = 0; + return; + } + + int exclusive_token_prefix = token_prefix - tentative; + int remaining_active = s_total_active - active_prefix; + + // Budget: total threshold minus tokens already allocated before me, + // minus at-least-1 reservation for every active item after me. + int budget = threshold - exclusive_token_prefix - remaining_active; + + int actual; + if (budget <= 1) { + actual = 1; // base token only + } else { + actual = min(tentative, budget); + } + + seq_lens_this_time[tid] = actual; + + // Copy draft tokens (slots 1..actual-1) from scratch to output + if (actual > 1) { + int64_t *dst = draft_tokens + tid * draft_tokens_stride; + const int64_t *src = draft_tokens_copy + tid * draft_tokens_stride; + for (int k = 1; k < actual; k++) { + dst[k] = src[k]; + } + } + } +} + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +// ============================================================ +// CPU path — preserved from original ngram_match.cc for +// backward compatibility with CPU-only callers and tests. 
+// ============================================================ +static int sum_cpu(const int *value, int num) { + int sum_value = 0; + for (int i = 0; i <= num; i++) { + sum_value += value[i]; + } + return sum_value; +} + +static void find_candidate_pred_tokens(const int64_t *input_ids, + const int64_t *input_ids_len, + const int64_t *token_ids_all, + const int64_t *prompt_lens, + const int64_t *step_idx, + const int *draft_token_num, + int64_t *draft_tokens, + int32_t *seq_lens_this_time, + int32_t *seq_lens_encoder, + int32_t *seq_lens_decoder, + int64_t *max_dec_len, + int64_t input_ids_stride, + int64_t max_model_len, + int64_t draft_tokens_stride, + int64_t max_batch_size, + int max_ngram_size = 3, + int max_draft_tokens = 10) { + int threshold = 128; + char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"); + if (env_var) { + threshold = std::stoi(env_var); + } + int unprocessed_batch_size = 0; + for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { + if (seq_lens_encoder[batch_idx] > 0 || seq_lens_decoder[batch_idx] > 0) { + unprocessed_batch_size++; + } + } + for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) { + max_draft_tokens = + std::min(static_cast(draft_token_num[batch_idx]), + max_dec_len[batch_idx] - step_idx[batch_idx] - 1); + if (seq_lens_encoder[batch_idx] > 0) { + continue; + } else if (seq_lens_decoder[batch_idx] == 0) { + seq_lens_this_time[batch_idx] = 0; + continue; + } + + const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride; + int64_t *cur_draft_tokens = draft_tokens + batch_idx * draft_tokens_stride; + const int64_t *cur_pre_ids = + token_ids_all + batch_idx * max_model_len + prompt_lens[batch_idx]; + const int64_t cur_step_idx = step_idx[batch_idx]; + const int64_t cur_input_ids_len = input_ids_len[batch_idx]; + seq_lens_this_time[batch_idx] = 1; + unprocessed_batch_size--; + + auto sum_token_num = sum_cpu(seq_lens_this_time, batch_idx); + int left_min_token_num = 
unprocessed_batch_size; + + if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) { + int tmp_max_draft_tokens = threshold - sum_token_num - left_min_token_num; + max_draft_tokens = tmp_max_draft_tokens < max_draft_tokens + ? tmp_max_draft_tokens + : max_draft_tokens; + } + + if (sum_token_num + left_min_token_num >= threshold - 1) { + continue; + } + + for (int ngram_size = max_ngram_size; ngram_size > 0; --ngram_size) { + if (cur_step_idx < ngram_size) { + continue; + } + const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size); + + bool match_input = false; + for (int64_t i = 0; i <= cur_input_ids_len - ngram_size; ++i) { + bool match = true; + for (int j = 0; j < ngram_size; j++) { + if (ngram[j] != cur_input_ids[i + j]) { + match = false; + break; + } + } + if (match) { + int64_t start_idx = i + ngram_size; + int64_t end_idx = + std::min(start_idx + max_draft_tokens, cur_input_ids_len); + if (start_idx >= end_idx) continue; + + int64_t cur_draft_token_num = end_idx - start_idx; + seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; + memcpy(cur_draft_tokens + 1, + cur_input_ids + start_idx, + sizeof(int64_t) * cur_draft_token_num); + ngram_size = 0; + match_input = true; + break; + } + } + if (!match_input) { + for (int64_t i = 0; i <= cur_step_idx - ngram_size; ++i) { + bool match = true; + for (int j = 0; j < ngram_size; j++) { + if (ngram[j] != cur_pre_ids[i + j]) { + match = false; + break; + } + } + if (match) { + int64_t start_idx = i + ngram_size; + int64_t end_idx = + std::min(start_idx + max_draft_tokens, cur_step_idx); + int64_t cur_draft_token_num = end_idx - start_idx; + if (start_idx >= end_idx) continue; + + seq_lens_this_time[batch_idx] = cur_draft_token_num + 1; + memcpy(cur_draft_tokens + 1, + cur_pre_ids + start_idx, + sizeof(int64_t) * cur_draft_token_num); + ngram_size = 0; + break; + } + } + } + } + } +} + +// ============================================================ +// GPU path — Two-phase parallel 
CUDA kernels for ngram matching. +// +// Phase 1: <<>> — parallel sliding-window +// search within each batch item (NGRAM_BLOCK_THREADS threads +// per block). Also copies matched draft tokens to scratch. +// Phase 2: <<<1, NGRAM_GATHER_THREADS>>> — CUB BlockScan prefix-sum +// threshold enforcement + final token copy. +// +// Phase 1 is O(bsz × seq_len × ngram_size) distributed across +// bsz × NGRAM_BLOCK_THREADS threads. Phase 2 is O(bsz) with scans. +// ============================================================ + +void NgramMatch(const paddle::Tensor &input_ids, + const paddle::Tensor &input_ids_len, + const paddle::Tensor &token_ids_all, + const paddle::Tensor &prompt_lens, + const paddle::Tensor &step_idx, + const paddle::Tensor &draft_token_num, + const paddle::Tensor &draft_tokens, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_decoder, + const paddle::Tensor &max_dec_len, + const int max_ngram_size, + const int max_draft_tokens) { + auto input_ids_shape = input_ids.shape(); + const int64_t input_ids_stride = input_ids_shape[1]; + + const int64_t max_model_len = token_ids_all.shape()[1]; + + auto draft_tokens_shape = draft_tokens.shape(); + const int64_t draft_tokens_stride = draft_tokens_shape[1]; + + const int64_t max_batch_size = seq_lens_this_time.shape()[0]; + + int threshold = 128; + const char *env_var = getenv("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"); + if (env_var) { + threshold = std::stoi(env_var); + } + + if (input_ids.is_gpu()) { + auto stream = input_ids.stream(); + + // Persistent scratch buffers for Phase 1 → Phase 2 communication. + // Cached across calls to avoid per-invocation allocation overhead. + // Write-before-read pattern (Phase 1 writes all elements before + // Phase 2 reads) means no initialization is needed between calls. + // Safety: single-threaded Python caller + CUDA stream serialization. 
+    static paddle::Tensor s_draft_copy;
+    static paddle::Tensor s_seqlens_copy;
+    static int64_t s_scratch_batch = 0;
+    static int64_t s_scratch_stride = 0;
+
+    if (max_batch_size > s_scratch_batch ||
+        draft_tokens_stride > s_scratch_stride) {
+      s_draft_copy = paddle::empty({max_batch_size, draft_tokens_stride},
+                                   paddle::DataType::INT64,
+                                   input_ids.place());
+      s_seqlens_copy = paddle::empty(
+          {max_batch_size}, paddle::DataType::INT32, input_ids.place());
+      s_scratch_batch = max_batch_size;
+      s_scratch_stride = draft_tokens_stride;
+    }
+    auto &draft_tokens_copy = s_draft_copy;
+    auto &seq_lens_this_time_copy = s_seqlens_copy;
+
+    // Fail-fast: BlockScan Phase 2 requires max_batch_size ≤ block size.
+    PD_CHECK(max_batch_size <= NGRAM_GATHER_THREADS,
+             "ngram_match: max_batch_size exceeds NGRAM_GATHER_THREADS");
+
+    // Phase 1: parallel search — one block per batch item.
+    // Also copies matched tokens to scratch and writes tentative seq_lens.
+    ngram_match_search_kernel<<<max_batch_size,
+                                NGRAM_BLOCK_THREADS,
+                                0,
+                                stream>>>(
+        input_ids.data<int64_t>(),
+        input_ids_len.data<int64_t>(),
+        token_ids_all.data<int64_t>(),
+        prompt_lens.data<int64_t>(),
+        step_idx.data<int64_t>(),
+        draft_token_num.data<int>(),
+        seq_lens_encoder.data<int32_t>(),
+        seq_lens_decoder.data<int32_t>(),
+        max_dec_len.data<int64_t>(),
+        draft_tokens_copy.data<int64_t>(),
+        seq_lens_this_time_copy.data<int32_t>(),
+        input_ids_stride,
+        max_model_len,
+        draft_tokens_stride,
+        max_batch_size,
+        max_ngram_size);
+
+    // Phase 2: BlockScan threshold enforcement + final token copy.
+    // <<<1, NGRAM_GATHER_THREADS>>> — all batch items handled by one block.
+    ngram_match_gather_kernel<<<1, NGRAM_GATHER_THREADS, 0, stream>>>(
+        draft_tokens_copy.data<int64_t>(),
+        seq_lens_this_time_copy.data<int32_t>(),
+        seq_lens_encoder.data<int32_t>(),
+        const_cast<int64_t *>(draft_tokens.data<int64_t>()),
+        const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
+        draft_tokens_stride,
+        max_batch_size,
+        threshold);
+  } else {
+    find_candidate_pred_tokens(
+        input_ids.data<int64_t>(),
+        input_ids_len.data<int64_t>(),
+        token_ids_all.data<int64_t>(),
+        prompt_lens.data<int64_t>(),
+        step_idx.data<int64_t>(),
+        draft_token_num.data<int>(),
+        const_cast<int64_t *>(draft_tokens.data<int64_t>()),
+        const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
+        const_cast<int32_t *>(seq_lens_encoder.data<int32_t>()),
+        const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
+        const_cast<int64_t *>(max_dec_len.data<int64_t>()),
+        input_ids_stride,
+        max_model_len,
+        draft_tokens_stride,
+        max_batch_size,
+        max_ngram_size,
+        max_draft_tokens);
+  }
+}
+
+PD_BUILD_STATIC_OP(ngram_match)
+    .Inputs({"input_ids",
+             "input_ids_len",
+             "token_ids_all",
+             "prompt_lens",
+             "step_idx",
+             "draft_token_num",
+             "draft_tokens",
+             "seq_lens_this_time",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "max_dec_len"})
+    .Attrs({"max_ngram_size: int", "max_draft_tokens: int"})
+    .Outputs({"draft_tokens_out", "seq_lens_this_time_out"})
+    .SetKernelFn(PD_KERNEL(NgramMatch))
+    .SetInplaceMap({{"draft_tokens", "draft_tokens_out"},
+                    {"seq_lens_this_time", "seq_lens_this_time_out"}});
diff --git a/custom_ops/gpu_ops/speculate_decoding/ngram_match_common.cuh b/custom_ops/gpu_ops/speculate_decoding/ngram_match_common.cuh
new file mode 100644
index 00000000000..af096b72481
--- /dev/null
+++ b/custom_ops/gpu_ops/speculate_decoding/ngram_match_common.cuh
@@ -0,0 +1,151 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <cstdint>
+
+// Shared ngram matching logic used by both ngram_match_kernel and
+// ngram_match_mixed_kernel. Extracted per upstream requirement:
+// "两个Kernel逻辑有较为相似部分,Kernel 形式为提取共用的匹配逻辑,外加业务逻辑"
+//
+// Two-phase parallel architecture:
+// Phase 1 — <<<bsz, NGRAM_BLOCK_THREADS>>>: parallel sliding-window
+// search + tentative token copy (one block per batch item).
+// Phase 2 — <<<1, NGRAM_GATHER_THREADS>>>: parallel threshold truncation
+// via CUB BlockScan prefix-sum, then copy winners to output
+
+#define NGRAM_BLOCK_THREADS 1024
+#define NGRAM_GATHER_THREADS 1024
+
+// ------------------------------------------------------------
+// atomicMin for int64_t via CAS loop. CUDA has no native
+// int64 atomicMin. All values are non-negative positions or
+// INT64_MAX, so unsigned reinterpretation is safe.
+// ------------------------------------------------------------
+__device__ __forceinline__ void atomicMin64(int64_t *addr, int64_t val) {
+  unsigned long long *addr_ull = reinterpret_cast<unsigned long long *>(addr);
+  unsigned long long val_ull = static_cast<unsigned long long>(val);
+  // Non-atomic initial read is intentional: the CAS loop below detects and
+  // retries on any stale value, so a torn read here is harmless.
+  unsigned long long old = *addr_ull;
+  while (val_ull < old) {
+    unsigned long long assumed = old;
+    old = atomicCAS(addr_ull, assumed, val_ull);
+    if (old == assumed) break;
+  }
+}
+
+// ------------------------------------------------------------
+// parallel_ngram_search — Block-cooperative haystack search.
+//
+// Template-specialized for common ngram sizes (1-3) to enable:
+// - Register caching of ngram tokens (avoid repeated global loads)
+// - Full compile-time unrolling of inner comparison loop
+// - __restrict__ hints for pointer non-aliasing optimization
+//
+// Runtime dispatcher preserves the original call signature so both
+// ngram_match.cu and ngram_match_mixed.cu work transparently.
+//
+// Early-exit (A2): once a match is found (s_min_pos < INT64_MAX),
+// threads that are past the current best skip remaining work.
+//
+// Returns the leftmost match position, or INT64_MAX if no match.
+// Caller must provide __shared__ int64_t s_min_pos.
+// ------------------------------------------------------------
+template <int NGRAM_SIZE>
+__device__ __forceinline__ int64_t
+parallel_ngram_search_specialized(const int64_t *__restrict__ haystack,
+                                  int64_t haystack_len,
+                                  const int64_t *__restrict__ ngram,
+                                  int64_t *s_min_pos) {
+  int tid = threadIdx.x;
+  int nthreads = blockDim.x;
+
+  if (tid == 0) *s_min_pos = INT64_MAX;
+  __syncthreads();
+
+  int64_t search_len = haystack_len - NGRAM_SIZE + 1;
+  if (search_len <= 0) {
+    __syncthreads();
+    return *s_min_pos;
+  }
+
+  // Cache ngram tokens in registers — eliminates repeated global reads.
+  int64_t ng[NGRAM_SIZE];
+#pragma unroll
+  for (int j = 0; j < NGRAM_SIZE; j++) ng[j] = ngram[j];
+
+  for (int64_t i = tid; i < search_len; i += nthreads) {
+    // A2: Early-exit — skip positions beyond current best match.
+    if (i > *s_min_pos) break;
+
+    bool match = true;
+#pragma unroll
+    for (int j = 0; j < NGRAM_SIZE; j++) {
+      if (ng[j] != haystack[i + j]) {
+        match = false;
+        break;
+      }
+    }
+    if (match) atomicMin64(s_min_pos, i);
+  }
+  __syncthreads();
+  return *s_min_pos;
+}
+
+// Runtime dispatcher — same signature as original, transparent to callers.
+__device__ __forceinline__ int64_t +parallel_ngram_search(const int64_t *__restrict__ haystack, + int64_t haystack_len, + const int64_t *__restrict__ ngram, + int ngram_size, + int64_t *s_min_pos) { + switch (ngram_size) { + case 1: + return parallel_ngram_search_specialized<1>( + haystack, haystack_len, ngram, s_min_pos); + case 2: + return parallel_ngram_search_specialized<2>( + haystack, haystack_len, ngram, s_min_pos); + case 3: + return parallel_ngram_search_specialized<3>( + haystack, haystack_len, ngram, s_min_pos); + default: + break; + } + // Fallback for ngram_size > 3 — runtime loop, no unrolling. + int tid = threadIdx.x; + int nthreads = blockDim.x; + if (tid == 0) *s_min_pos = INT64_MAX; + __syncthreads(); + int64_t search_len = haystack_len - ngram_size + 1; + if (search_len <= 0) { + __syncthreads(); + return *s_min_pos; + } + for (int64_t i = tid; i < search_len; i += nthreads) { + if (i > *s_min_pos) break; + bool match = true; + for (int j = 0; j < ngram_size; j++) { + if (ngram[j] != haystack[i + j]) { + match = false; + break; + } + } + if (match) atomicMin64(s_min_pos, i); + } + __syncthreads(); + return *s_min_pos; +} diff --git a/fastdeploy/spec_decode/mtp.py b/fastdeploy/spec_decode/mtp.py index 5b62985e92e..9a5d3fa4585 100644 --- a/fastdeploy/spec_decode/mtp.py +++ b/fastdeploy/spec_decode/mtp.py @@ -1176,28 +1176,20 @@ def _update_status(self): ) def _extend_draft_token_with_ngram_match(self): - # TODO(liuzichang): Optimize this Kernel to CUDA Kernel to reduce lantency - device = paddle.CUDAPinnedPlace() - - draft_tokens = self.target_model_inputs["draft_tokens"].cpu() - seq_lens_this_time = self.target_model_inputs["seq_lens_this_time"].cpu() - seq_lens_decoder = self.model_inputs["seq_lens_decoder"].cpu() hybrid_mtp_ngram( - self.model_inputs["input_ids_cpu"], - self.model_inputs["input_ids_len"], - self.model_inputs["pre_ids"]._copy_to(device, True), - self.model_inputs["step_idx"].cpu(), - 
self.target_model_inputs["actual_draft_token_num"].cpu(), - draft_tokens, - seq_lens_this_time, - seq_lens_decoder, - self.model_inputs["max_dec_len"].cpu(), + self.model_inputs["input_ids_cpu"].cuda(), + self.model_inputs["input_ids_len"].cuda(), + self.model_inputs["pre_ids"], + self.model_inputs["step_idx"], + self.target_model_inputs["actual_draft_token_num"], + self.target_model_inputs["draft_tokens"], + self.target_model_inputs["seq_lens_this_time"], + self.model_inputs["seq_lens_decoder"], + self.model_inputs["max_dec_len"], self.max_ngram_size, self.min_ngram_size, self.max_draft_token_num, ) - self.target_model_inputs["draft_tokens"][:] = draft_tokens.cuda() - self.target_model_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() def _run_impl( self, full_hidden_states: paddle.Tensor, step_use_cudagraph: bool = False, is_dummy_run: bool = False diff --git a/fastdeploy/spec_decode/ngram.py b/fastdeploy/spec_decode/ngram.py index b64e8fb5790..2de823b36da 100644 --- a/fastdeploy/spec_decode/ngram.py +++ b/fastdeploy/spec_decode/ngram.py @@ -37,37 +37,31 @@ def __init__(self, fd_config: "FDConfig"): super().__init__(fd_config) self.max_ngram_size = self.speculative_config.max_ngram_size self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu() + self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cuda() def update(self, bid: int, seq_len: int): """ update """ self.input_ids_len[bid] = seq_len + self.input_ids_len_gpu[bid] = seq_len def _run_impl(self, share_inputs): """ run """ - draft_tokens = share_inputs["draft_tokens"].cpu() - seq_lens_this_time = share_inputs["seq_lens_this_time"].cpu() - seq_lens_encoder = share_inputs["seq_lens_encoder"].cpu() - seq_lens_decoder = share_inputs["seq_lens_decoder"].cpu() - ngram_match( - share_inputs["input_ids_cpu"], - self.input_ids_len.cpu(), - share_inputs["token_ids_all"].cpu(), - share_inputs["prompt_lens"].cpu(), - share_inputs["step_idx"].cpu(), - 
share_inputs["actual_draft_token_num"].cpu(), - draft_tokens, - seq_lens_this_time, - seq_lens_encoder, - seq_lens_decoder, - share_inputs["max_dec_len"].cpu(), + share_inputs["input_ids_cpu"].cuda(), + self.input_ids_len_gpu, + share_inputs["token_ids_all"], + share_inputs["prompt_lens"], + share_inputs["step_idx"], + share_inputs["actual_draft_token_num"], + share_inputs["draft_tokens"], + share_inputs["seq_lens_this_time"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_decoder"], + share_inputs["max_dec_len"], self.max_ngram_size, self.max_draft_token_num, ) - share_inputs["draft_tokens"][:] = draft_tokens.cuda() - share_inputs["seq_lens_encoder"][:] = seq_lens_encoder.cuda() - share_inputs["seq_lens_this_time"][:] = seq_lens_this_time.cuda() diff --git a/tests/spec_decode/test_benchmark_ngram_kernel.py b/tests/spec_decode/test_benchmark_ngram_kernel.py new file mode 100644 index 00000000000..6fb13be7d13 --- /dev/null +++ b/tests/spec_decode/test_benchmark_ngram_kernel.py @@ -0,0 +1,358 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Multi-dimension benchmark for ngram_match GPU kernel vs CPU copy path. + +Matches NKNaN's profiling methodology (5 experiment groups) using +FastDeploy's native ngram_match op interface. + +Groups: + 1. seq_len — [1024, 4096, 16384, 65536, 131072] + 2. batch_size — [1, 8, 32, 128, 512] + 3. 
ngram hit — [high_input, high_pre, low_input, low_pre, none] + 4. threshold — [16, 32, 64, 128, 256] + 5. threshold × batch (batch=128) + +Run: + cd FastDeploy && python tests/spec_decode/test_benchmark_ngram_kernel.py +""" +import os +import sys +import time +import unittest + +import numpy as np +import paddle + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + +MAX_NGRAM_SIZE = 3 +MAX_DRAFT_TOKENS = 10 +NUM_ITERS = 1000 +WARMUP = 5 + + +def _build_data(batch_size, seq_len, hit_type="low_input", seed=42): + """ + Build test tensors with controlled ngram hit placement. + + hit_type controls where the ngram match is found: + - high_input: match near start of input_ids (fast find) + - high_pre: match near start of token_ids_all gen tokens + - low_input: match near end of input_ids (worst-case scan) + - low_pre: match near end of token_ids_all gen tokens + - none: no planted match (full scan, no hit) + """ + rng = np.random.RandomState(seed) + step_idx_val = max(MAX_NGRAM_SIZE + 2, 20) + pre_len = step_idx_val + 1 + max_model_len = max(seq_len + 64, pre_len + 64) + + input_ids = rng.randint(10, 500, (batch_size, seq_len)).astype(np.int64) + token_ids_all = rng.randint(10, 500, (batch_size, max_model_len)).astype(np.int64) + pattern = np.arange(1001, 1001 + MAX_NGRAM_SIZE, dtype=np.int64) + + for b in range(batch_size): + # Plant pattern in token_ids_all at step_idx alignment (the ngram to search for) + ng_start = step_idx_val + 1 - MAX_NGRAM_SIZE + token_ids_all[b, ng_start : step_idx_val + 1] = pattern + + if hit_type == "high_input": + pos = 5 + if pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS <= seq_len: + input_ids[b, pos : pos + MAX_NGRAM_SIZE] = pattern + input_ids[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( + 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 + ) + + elif hit_type == "high_pre": + pos = 5 + if pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS < ng_start: + token_ids_all[b, pos : pos + MAX_NGRAM_SIZE] = 
pattern + token_ids_all[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( + 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 + ) + + elif hit_type == "low_input": + pos = seq_len - MAX_NGRAM_SIZE - MAX_DRAFT_TOKENS - 5 + if pos > 0: + input_ids[b, pos : pos + MAX_NGRAM_SIZE] = pattern + input_ids[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( + 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 + ) + + elif hit_type == "low_pre": + pos = step_idx_val - MAX_NGRAM_SIZE - MAX_DRAFT_TOKENS - 5 + if pos > 0 and pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS < ng_start: + token_ids_all[b, pos : pos + MAX_NGRAM_SIZE] = pattern + token_ids_all[b, pos + MAX_NGRAM_SIZE : pos + MAX_NGRAM_SIZE + MAX_DRAFT_TOKENS] = np.arange( + 2001, 2001 + MAX_DRAFT_TOKENS, dtype=np.int64 + ) + + elif hit_type == "none": + pass # No match planted — random data only + + input_ids_len = np.full((batch_size, 1), seq_len, dtype=np.int64) + prompt_lens = np.zeros((batch_size, 1), dtype=np.int64) + step_idx = np.full((batch_size, 1), step_idx_val, dtype=np.int64) + draft_token_num = np.full((batch_size, 1), MAX_DRAFT_TOKENS, dtype=np.int32) + draft_tokens = np.zeros((batch_size, MAX_DRAFT_TOKENS + 1), dtype=np.int64) + seq_lens_this_time = np.ones(batch_size, dtype=np.int32) + seq_lens_encoder = np.zeros(batch_size, dtype=np.int32) + seq_lens_decoder = np.ones(batch_size, dtype=np.int32) + max_dec_len = np.full((batch_size, 1), 1048576, dtype=np.int64) + + return { + "input_ids": input_ids, + "input_ids_len": input_ids_len, + "token_ids_all": token_ids_all, + "prompt_lens": prompt_lens, + "step_idx": step_idx, + "draft_token_num": draft_token_num, + "draft_tokens": draft_tokens, + "seq_lens_this_time": seq_lens_this_time, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "max_dec_len": max_dec_len, + } + + +def _to_gpu(np_dict): + out = {} + for k, v in np_dict.items(): + out[k] = paddle.to_tensor(v, 
place=paddle.CUDAPlace(0)) + return out + + +def _run_gpu(ngram_match_fn, gpu_data): + """Run GPU kernel (tensors already on GPU).""" + ngram_match_fn( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + MAX_NGRAM_SIZE, + MAX_DRAFT_TOKENS, + ) + + +def _time_gpu(ngram_match_fn, batch_size, seq_len, hit_type, n_runs): + """Time GPU kernel with pre-created tensors (no data creation in loop).""" + gpu_data = _to_gpu(_build_data(batch_size, seq_len, hit_type)) + # Pre-allocate mutable output buffers once — avoids per-iteration + # paddle.zeros/ones which add ~20-40µs allocation + fill overhead. + draft_buf = paddle.zeros([batch_size, MAX_DRAFT_TOKENS + 1], dtype="int64").cuda() + seqlens_buf = paddle.ones([batch_size], dtype="int32").cuda() + # Warmup + for _ in range(WARMUP): + seqlens_buf.fill_(1) + gpu_data["draft_tokens"] = draft_buf + gpu_data["seq_lens_this_time"] = seqlens_buf + _run_gpu(ngram_match_fn, gpu_data) + paddle.device.synchronize() + + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + seqlens_buf.fill_(1) + gpu_data["draft_tokens"] = draft_buf + gpu_data["seq_lens_this_time"] = seqlens_buf + _run_gpu(ngram_match_fn, gpu_data) + paddle.device.synchronize() + return (time.perf_counter() - t0) / n_runs * 1e6 # microseconds + + +def _time_cpu_copy(batch_size, seq_len, hit_type, n_runs): + """Time the old CPU-copy path: GPU→CPU transfer + CPU→GPU transfer back.""" + gpu_data = _to_gpu(_build_data(batch_size, seq_len, hit_type)) + # Warmup + for _ in range(WARMUP): + _ = {k: v.cpu() for k, v in gpu_data.items()} + paddle.device.synchronize() + + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + cpu_copy = {k: v.cpu() for k, v in 
gpu_data.items()} + _ = cpu_copy["draft_tokens"].cuda() + _ = cpu_copy["seq_lens_this_time"].cuda() + paddle.device.synchronize() + return (time.perf_counter() - t0) / n_runs * 1e6 # microseconds + + +def _print_table(title, header, rows): + """Print formatted benchmark table.""" + print(f"\n{'=' * 80}") + print(title) + print(f"{'─' * 80}") + print(header) + print(f"{'─' * 80}") + for row in rows: + print(row) + print(f"{'=' * 80}") + + +class TestNgramBenchmarkGroups(unittest.TestCase): + """Multi-dimension benchmark matching NKNaN's 5-group methodology.""" + + @classmethod + def setUpClass(cls): + if not paddle.is_compiled_with_cuda(): + raise unittest.SkipTest("CUDA not available") + paddle.set_device("gpu") + try: + from fastdeploy.model_executor.ops.gpu import ngram_match + + cls.ngram_match = staticmethod(ngram_match) + except Exception as e: + raise unittest.SkipTest(f"Cannot import ngram_match op: {e}") + + def test_group1_seq_len(self): + """Group 1: Vary seq_len with fixed batch=16, threshold=512, hit=low_input.""" + seq_lens = [1024, 4096, 16384, 65536, 131072] + batch_size = 16 + hit_type = "low_input" + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = "512" + try: + rows = [] + for sl in seq_lens: + gpu_us = _time_gpu(self.ngram_match, batch_size, sl, hit_type, NUM_ITERS) + cpu_us = _time_cpu_copy(batch_size, sl, hit_type, NUM_ITERS) + speedup = cpu_us / gpu_us if gpu_us > 0 else 0 + rows.append(f"{sl:>8} {gpu_us:>12.1f} {cpu_us:>12.1f} {speedup:>8.2f}x") + _print_table( + f"Group 1: seq_len (batch={batch_size}, threshold=512, hit={hit_type}, {NUM_ITERS} runs)", + f"{'seq_len':>8} {'GPU (µs)':>12} {'CPU copy (µs)':>12} {'Speedup':>8}", + rows, + ) + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + + def test_group2_batch_size(self): + """Group 2: Vary 
batch_size with fixed seq_len=16384, threshold=8192, hit=low_input.""" + batch_sizes = [1, 8, 32, 128, 512] + seq_len = 16384 + hit_type = "low_input" + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = "8192" + try: + rows = [] + for bsz in batch_sizes: + gpu_us = _time_gpu(self.ngram_match, bsz, seq_len, hit_type, NUM_ITERS) + cpu_us = _time_cpu_copy(bsz, seq_len, hit_type, NUM_ITERS) + speedup = cpu_us / gpu_us if gpu_us > 0 else 0 + rows.append(f"{bsz:>8} {gpu_us:>12.1f} {cpu_us:>12.1f} {speedup:>8.2f}x") + _print_table( + f"Group 2: batch_size (seq_len={seq_len}, threshold=8192, hit={hit_type}, {NUM_ITERS} runs)", + f"{'batch':>8} {'GPU (µs)':>12} {'CPU copy (µs)':>12} {'Speedup':>8}", + rows, + ) + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + + def test_group3_ngram_hit(self): + """Group 3: Vary hit pattern with fixed batch=16, seq_len=32768, threshold=512.""" + hit_types = ["high_input", "high_pre", "low_input", "low_pre", "none"] + batch_size = 16 + seq_len = 32768 + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = "512" + try: + rows = [] + for ht in hit_types: + gpu_us = _time_gpu(self.ngram_match, batch_size, seq_len, ht, NUM_ITERS) + cpu_us = _time_cpu_copy(batch_size, seq_len, ht, NUM_ITERS) + speedup = cpu_us / gpu_us if gpu_us > 0 else 0 + rows.append(f"{ht:>12} {gpu_us:>12.1f} {cpu_us:>12.1f} {speedup:>8.2f}x") + _print_table( + f"Group 3: ngram hit (batch={batch_size}, seq_len={seq_len}, threshold=512, {NUM_ITERS} runs)", + f"{'hit_type':>12} {'GPU (µs)':>12} {'CPU copy (µs)':>12} {'Speedup':>8}", + rows, + ) + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + 
+ def test_group4_threshold(self): + """Group 4: Vary threshold with fixed batch=8, seq_len=32768, hit=low_input.""" + thresholds = [16, 32, 64, 128, 256] + batch_size = 8 + seq_len = 32768 + hit_type = "low_input" + rows = [] + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + try: + for thr in thresholds: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(thr) + gpu_us = _time_gpu(self.ngram_match, batch_size, seq_len, hit_type, NUM_ITERS) + cpu_us = _time_cpu_copy(batch_size, seq_len, hit_type, NUM_ITERS) + speedup = cpu_us / gpu_us if gpu_us > 0 else 0 + rows.append(f"{thr:>8} {gpu_us:>12.1f} {cpu_us:>12.1f} {speedup:>8.2f}x") + _print_table( + f"Group 4: threshold (batch={batch_size}, seq_len={seq_len}, hit={hit_type}, {NUM_ITERS} runs)", + f"{'thresh':>8} {'GPU (µs)':>12} {'CPU copy (µs)':>12} {'Speedup':>8}", + rows, + ) + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + + def test_group5_threshold_x_batch(self): + """Group 5: Vary threshold with large batch=128 to expose truncation effects.""" + thresholds = [16, 32, 64, 128, 256] + batch_size = 128 + seq_len = 32768 + hit_type = "low_input" + rows = [] + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + try: + for thr in thresholds: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(thr) + gpu_us = _time_gpu(self.ngram_match, batch_size, seq_len, hit_type, NUM_ITERS) + cpu_us = _time_cpu_copy(batch_size, seq_len, hit_type, NUM_ITERS) + speedup = cpu_us / gpu_us if gpu_us > 0 else 0 + rows.append(f"{thr:>8} {gpu_us:>12.1f} {cpu_us:>12.1f} {speedup:>8.2f}x") + _print_table( + f"Group 5: threshold×batch (batch={batch_size}, seq_len={seq_len}, hit={hit_type}, {NUM_ITERS} runs)", + f"{'thresh':>8} {'GPU (µs)':>12} {'CPU copy (µs)':>12} {'Speedup':>8}", + rows, + ) + finally: + if old_env is None: + 
os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/spec_decode/test_ngram_gpu_kernel.py b/tests/spec_decode/test_ngram_gpu_kernel.py new file mode 100644 index 00000000000..f4b5be185ac --- /dev/null +++ b/tests/spec_decode/test_ngram_gpu_kernel.py @@ -0,0 +1,1054 @@ +#!/usr/bin/env python3 +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Correctness + latency test for GPU ngram_match & hybrid_mtp_ngram kernels. + +Run on AI Studio V100: + cd FastDeploy && pip install -e . 
&& python tests/spec_decode/test_ngram_gpu_kernel.py + +Or standalone (compile custom ops first): + bash build.sh 0 && python tests/spec_decode/test_ngram_gpu_kernel.py +""" +import os +import sys +import time +import unittest + +import numpy as np +import paddle + +# Ensure FastDeploy ops are importable +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../..")) + + +def _cpu_ngram_match( + input_ids, + input_ids_len, + token_ids_all, + prompt_lens, + step_idx, + draft_token_num, + draft_tokens, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + max_dec_len, + max_ngram_size, + max_draft_tokens_param, + threshold=128, +): + """Pure NumPy reference matching the original ngram_match.cc logic.""" + # Flatten (N,1) shaped arrays to 1D for scalar indexing + max_dec_len = max_dec_len.ravel() + step_idx = step_idx.ravel() + draft_token_num = draft_token_num.ravel() + prompt_lens = prompt_lens.ravel() + input_ids_len = input_ids_len.ravel() + max_batch_size = seq_lens_this_time.shape[0] + + unprocessed = sum(1 for b in range(max_batch_size) if seq_lens_encoder[b] > 0 or seq_lens_decoder[b] > 0) + + for batch_idx in range(max_batch_size): + remaining = int(max_dec_len[batch_idx] - step_idx[batch_idx] - 1) + mdt = min(int(draft_token_num[batch_idx]), remaining) + + if seq_lens_encoder[batch_idx] > 0: + continue + elif seq_lens_decoder[batch_idx] == 0: + seq_lens_this_time[batch_idx] = 0 + continue + + cur_input_ids = input_ids[batch_idx] + cur_draft = draft_tokens[batch_idx] + prompt_len = int(prompt_lens[batch_idx]) + cur_pre_ids = token_ids_all[batch_idx, prompt_len:] + cur_step = int(step_idx[batch_idx]) + cur_ids_len = int(input_ids_len[batch_idx]) + seq_lens_this_time[batch_idx] = 1 + unprocessed -= 1 + + sum_tok = sum(int(seq_lens_this_time[i]) for i in range(batch_idx + 1)) + left_min = unprocessed + + if sum_tok + mdt + left_min > threshold: + mdt = min(mdt, threshold - sum_tok - left_min) + if sum_tok + left_min >= threshold - 1: + continue + 
+ for ngram_size in range(max_ngram_size, 0, -1): + if cur_step < ngram_size: + continue + ngram = cur_pre_ids[cur_step + 1 - ngram_size : cur_step + 1] + + # Search in input_ids + match_input = False + for i in range(cur_ids_len - ngram_size + 1): + if np.array_equal(cur_input_ids[i : i + ngram_size], ngram): + start = i + ngram_size + end = min(start + mdt, cur_ids_len) + if start >= end: + continue + n = end - start + seq_lens_this_time[batch_idx] = n + 1 + cur_draft[1 : 1 + n] = cur_input_ids[start : start + n] + match_input = True + break + if match_input: + break + + # Search in pre_ids + found = False + for i in range(cur_step - ngram_size + 1): + if np.array_equal(cur_pre_ids[i : i + ngram_size], ngram): + start = i + ngram_size + end = min(start + mdt, cur_step) + if start >= end: + continue + n = end - start + seq_lens_this_time[batch_idx] = n + 1 + cur_draft[1 : 1 + n] = cur_pre_ids[start : start + n] + found = True + break + if found: + break + + +def _cpu_hybrid_mtp_ngram( + input_ids, + input_ids_len, + pre_ids, + step_idx, + draft_token_num, + draft_tokens, + seq_lens_this_time, + seq_lens_decoder, + max_dec_len, + max_ngram_size, + min_ngram_size, + max_draft_tokens_param, + threshold=1024, +): + """Pure NumPy reference matching the original ngram_match_mixed.cu CPU logic.""" + # Flatten (N,1) shaped arrays to 1D for scalar indexing + max_dec_len = max_dec_len.ravel() + step_idx = step_idx.ravel() + draft_token_num = draft_token_num.ravel() + input_ids_len = input_ids_len.ravel() + max_batch_size = seq_lens_this_time.shape[0] + + unprocessed = sum(1 for b in range(max_batch_size) if seq_lens_decoder[b] > 0) + + for batch_idx in range(max_batch_size): + ori_slt = int(seq_lens_this_time[batch_idx]) + remaining = int(max_dec_len[batch_idx] - step_idx[batch_idx] - 1) + max_q = min(max_draft_tokens_param - ori_slt + 1, remaining) + + if ori_slt == 0 or max_q <= 0: + continue + + cur_input_ids = input_ids[batch_idx] + cur_draft = draft_tokens[batch_idx] + 
cur_pre = pre_ids[batch_idx] + cur_step = int(step_idx[batch_idx]) + cur_ids_len = int(input_ids_len[batch_idx]) + unprocessed -= 1 + + sum_tok = sum(int(seq_lens_this_time[i]) for i in range(batch_idx + 1)) + left_min = unprocessed + + if sum_tok + max_q + left_min > threshold: + max_q = min(max_q, threshold - sum_tok - left_min) + if sum_tok + left_min >= threshold - 1: + continue + + match_global = False + for ngram_size in range(max_ngram_size, min_ngram_size - 1, -1): + if match_global: + break + if cur_step < ngram_size: + continue + ngram = cur_pre[cur_step + 1 - ngram_size : cur_step + 1] + + # Search in input_ids + for i in range(cur_ids_len - ngram_size + 1): + if match_global: + break + if np.array_equal(cur_input_ids[i : i + ngram_size], ngram): + start = i + ngram_size + end = min(start + max_q, cur_ids_len) + if start >= end: + continue + n = end - start + seq_lens_this_time[batch_idx] = ori_slt + n + cur_draft[ori_slt : ori_slt + n] = cur_input_ids[start : start + n] + match_global = True + + # Search in pre_ids + if not match_global: + for i in range(cur_step - ngram_size + 1): + if match_global: + break + if np.array_equal(cur_pre[i : i + ngram_size], ngram): + start = i + ngram_size + end = min(start + max_q, cur_step) + if start >= end: + continue + n = end - start + seq_lens_this_time[batch_idx] = ori_slt + n + cur_draft[ori_slt : ori_slt + n] = cur_pre[start : start + n] + match_global = True + + +def _make_ngram_test_data(batch_size=4, input_len=64, max_model_len=256, max_draft=10, seed=42): + """Create realistic test tensors for ngram_match op.""" + rng = np.random.RandomState(seed) + vocab_size = 1000 + # Ensure max_model_len can hold prompt + generated tokens + max_model_len = max(max_model_len, input_len + 64) + + # Create prompt tokens with repeating patterns to ensure ngram matches + input_ids = rng.randint(0, vocab_size, (batch_size, input_len)).astype(np.int64) + input_ids_len = np.full((batch_size, 1), input_len, dtype=np.int64) + + # 
token_ids_all: [batch, max_model_len] — prompt + generated + token_ids_all = np.zeros((batch_size, max_model_len), dtype=np.int64) + prompt_lens = np.full((batch_size, 1), input_len, dtype=np.int64) + step_idx = np.zeros((batch_size, 1), dtype=np.int64) + draft_token_num = np.full((batch_size, 1), max_draft, dtype=np.int32) + draft_tokens = np.zeros((batch_size, max_draft + 1), dtype=np.int64) + seq_lens_this_time = np.ones(batch_size, dtype=np.int32) + seq_lens_encoder = np.zeros(batch_size, dtype=np.int32) + seq_lens_decoder = np.ones(batch_size, dtype=np.int32) + max_dec_len = np.full((batch_size, 1), 200, dtype=np.int64) + + for b in range(batch_size): + # Copy prompt into token_ids_all + token_ids_all[b, :input_len] = input_ids[b] + # Simulate generated tokens: copy contiguous blocks from prompt + # to guarantee ngram matches exist + gen_len = 20 + src = rng.randint(0, max(1, input_len - gen_len)) + token_ids_all[b, input_len : input_len + gen_len] = input_ids[b, src : src + gen_len] + # step_idx = last valid position (0-based index) + step_idx[b] = gen_len - 1 + + return { + "input_ids": input_ids, + "input_ids_len": input_ids_len, + "token_ids_all": token_ids_all, + "prompt_lens": prompt_lens, + "step_idx": step_idx, + "draft_token_num": draft_token_num, + "draft_tokens": draft_tokens, + "seq_lens_this_time": seq_lens_this_time, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_decoder": seq_lens_decoder, + "max_dec_len": max_dec_len, + } + + +def _make_mixed_test_data(batch_size=4, input_len=64, pre_ids_len=256, max_draft=10, seed=42): + """Create realistic test tensors for hybrid_mtp_ngram op.""" + rng = np.random.RandomState(seed) + vocab_size = 1000 + + input_ids = rng.randint(0, vocab_size, (batch_size, input_len)).astype(np.int64) + input_ids_len = np.full((batch_size, 1), input_len, dtype=np.int64) + + pre_ids = np.zeros((batch_size, pre_ids_len), dtype=np.int64) + step_idx = np.zeros((batch_size, 1), dtype=np.int64) + draft_token_num = 
np.full((batch_size, 1), max_draft, dtype=np.int32) + draft_tokens = np.zeros((batch_size, max_draft + 1), dtype=np.int64) + # For mixed: seq_lens_this_time starts at 1 (already has 1 draft token) + seq_lens_this_time = np.ones(batch_size, dtype=np.int32) + seq_lens_decoder = np.ones(batch_size, dtype=np.int32) + max_dec_len = np.full((batch_size, 1), 200, dtype=np.int64) + + for b in range(batch_size): + # Copy contiguous blocks from prompt to guarantee ngram matches + gen_len = 20 + src = rng.randint(0, max(1, input_len - gen_len)) + pre_ids[b, :gen_len] = input_ids[b, src : src + gen_len] + # step_idx = last valid position (0-based index) + step_idx[b] = gen_len - 1 + + return { + "input_ids": input_ids, + "input_ids_len": input_ids_len, + "pre_ids": pre_ids, + "step_idx": step_idx, + "draft_token_num": draft_token_num, + "draft_tokens": draft_tokens, + "seq_lens_this_time": seq_lens_this_time, + "seq_lens_decoder": seq_lens_decoder, + "max_dec_len": max_dec_len, + } + + +def _to_gpu(np_dict): + """Convert numpy dict to GPU paddle tensors.""" + out = {} + for k, v in np_dict.items(): + out[k] = paddle.to_tensor(v, place=paddle.CUDAPlace(0)) + return out + + +class TestNgramMatchKernel(unittest.TestCase): + """Test ngram_match GPU kernel correctness against CPU reference.""" + + @classmethod + def setUpClass(cls): + if not paddle.is_compiled_with_cuda(): + raise unittest.SkipTest("CUDA not available") + paddle.set_device("gpu") + # Import GPU ops (requires FastDeploy build) + try: + from fastdeploy.model_executor.ops.gpu import ngram_match + + cls.ngram_match = staticmethod(ngram_match) + except Exception as e: + raise unittest.SkipTest(f"Cannot import ngram_match op: {e}") + + def test_correctness_basic(self): + """Basic correctness: GPU output matches CPU reference.""" + data = _make_ngram_test_data(batch_size=4, seed=42) + max_ngram_size = 3 + max_draft_tokens = 10 + + # CPU reference + cpu_draft = data["draft_tokens"].copy() + cpu_slt = 
data["seq_lens_this_time"].copy() + _cpu_ngram_match( + data["input_ids"], + data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_encoder"], + data["seq_lens_decoder"], + data["max_dec_len"], + max_ngram_size, + max_draft_tokens, + ) + + # GPU kernel + gpu_data = _to_gpu(data) + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + max_ngram_size, + max_draft_tokens, + ) + paddle.device.synchronize() + + gpu_draft = gpu_data["draft_tokens"].numpy() + gpu_slt = gpu_data["seq_lens_this_time"].numpy() + + np.testing.assert_array_equal(gpu_slt, cpu_slt, err_msg="seq_lens_this_time mismatch") + np.testing.assert_array_equal(gpu_draft, cpu_draft, err_msg="draft_tokens mismatch") + + def test_correctness_varied_seeds(self): + """Test across multiple random seeds.""" + for seed in [0, 7, 123, 999]: + with self.subTest(seed=seed): + data = _make_ngram_test_data(batch_size=8, seed=seed) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_ngram_match( + data["input_ids"], + data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_encoder"], + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 10, + ) + gpu_data = _to_gpu(data) + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + 
gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_large_batch_long_seq(self): + """bsz=256, seq_len=128k — scale the reviewer demanded. + + Uses high threshold to ensure all batches exercise the parallel search + path (default threshold=128 would skip all batches at bsz=256). + """ + high_threshold = 100000 + data = _make_ngram_test_data(batch_size=256, input_len=131072, max_model_len=131072 + 64, seed=77) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_ngram_match( + data["input_ids"], + data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_encoder"], + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 10, + threshold=high_threshold, + ) + gpu_data = _to_gpu(data) + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(high_threshold) + try: + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_single_batch_long_seq(self): + """bsz=1, seq_len=128k — single long sequence.""" + data = 
_make_ngram_test_data(batch_size=1, input_len=131072, max_model_len=131072 + 64, seed=88) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_ngram_match( + data["input_ids"], + data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_encoder"], + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 10, + ) + gpu_data = _to_gpu(data) + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_many_short_seqs(self): + """bsz=256, seq_len=1k — many short sequences.""" + high_threshold = 100000 + data = _make_ngram_test_data(batch_size=256, input_len=1024, seed=55) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_ngram_match( + data["input_ids"], + data["input_ids_len"], + data["token_ids_all"], + data["prompt_lens"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_encoder"], + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 10, + threshold=high_threshold, + ) + gpu_data = _to_gpu(data) + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(high_threshold) + try: + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], 
+ gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_latency(self): + """Benchmark: GPU kernel latency vs CPU transfer overhead.""" + # Warmup + for _ in range(5): + d = _to_gpu(_make_ngram_test_data(batch_size=32, input_len=512, seed=42)) + self.ngram_match( + d["input_ids"], + d["input_ids_len"], + d["token_ids_all"], + d["prompt_lens"], + d["step_idx"], + d["draft_token_num"], + d["draft_tokens"], + d["seq_lens_this_time"], + d["seq_lens_encoder"], + d["seq_lens_decoder"], + d["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + + # GPU path: kernel execution only (pre-created tensors, no data transfer) + gpu_data = _to_gpu(_make_ngram_test_data(batch_size=32, input_len=512, seed=42)) + cpu_data = _make_ngram_test_data(batch_size=32, input_len=512, seed=42) + n_runs = 100 + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + t1 = time.perf_counter() + gpu_time_ms = (t1 - t0) / n_runs * 1000 + + # CPU path: simulate the old copy-to-CPU-and-back pattern + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + # Simulate old path: copy all tensors to CPU then back + 
cpu_tensors = {k: paddle.to_tensor(v, place=paddle.CPUPlace()) for k, v in cpu_data.items()} + _ = cpu_tensors["draft_tokens"].cuda() + _ = cpu_tensors["seq_lens_this_time"].cuda() + paddle.device.synchronize() + t1 = time.perf_counter() + cpu_copy_time_ms = (t1 - t0) / n_runs * 1000 + + print(f"\n{'='*60}") + print(f"LATENCY BENCHMARK (batch=32, input_len=512, {n_runs} runs)") + print(f" GPU kernel (zero-copy): {gpu_time_ms:.3f} ms/call") + print(f" CPU path (copy overhead): {cpu_copy_time_ms:.3f} ms/call") + print(f" Speedup: {cpu_copy_time_ms / gpu_time_ms:.2f}x") + print(f"{'='*60}") + + def test_latency_scaling(self): + """Benchmark GPU kernel across batch sizes to show Phase 2 scales.""" + batch_sizes = [32, 128, 256, 512, 1024] + input_len = 512 + n_runs = 50 + results = [] + + for bsz in batch_sizes: + # Pre-create tensors once per batch size (excluded from timing) + gpu_data = _to_gpu(_make_ngram_test_data(batch_size=bsz, input_len=input_len, seed=42)) + cpu_data = _make_ngram_test_data(batch_size=bsz, input_len=input_len, seed=42) + + # Warmup + for _ in range(3): + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + + # GPU kernel (pure kernel time — no data creation/transfer) + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + 
paddle.device.synchronize() + gpu_ms = (time.perf_counter() - t0) / n_runs * 1000 + + # CPU path: simulate the old copy-to-CPU-and-back pattern + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + cpu_tensors = {k: paddle.to_tensor(v, place=paddle.CPUPlace()) for k, v in cpu_data.items()} + _ = cpu_tensors["draft_tokens"].cuda() + _ = cpu_tensors["seq_lens_this_time"].cuda() + paddle.device.synchronize() + cpu_ms = (time.perf_counter() - t0) / n_runs * 1000 + + results.append((bsz, gpu_ms, cpu_ms)) + + print(f"\n{'='*72}") + print(f"SCALING BENCHMARK (input_len={input_len}, {n_runs} runs per config)") + print(f"{'─'*72}") + print(f"{'batch':>6} {'GPU (ms)':>10} {'CPU (ms)':>10} {'Speedup':>8} {'GPU/batch(µs)':>14}") + print(f"{'─'*72}") + for bsz, gpu_ms, cpu_ms in results: + speedup = cpu_ms / gpu_ms + per_batch_us = gpu_ms / bsz * 1000 + print(f"{bsz:>6} {gpu_ms:>10.3f} {cpu_ms:>10.3f} {speedup:>7.2f}x {per_batch_us:>14.3f}") + print(f"{'='*72}") + + def test_latency_extreme(self): + """Benchmark: GPU kernel at extreme scale (bsz=256, seq_len=128k). + + Addresses the NCU profiler worst-case scenario (bsz=256 + 128k) + raised in review. Tests with production-realistic thresholds + (8192, 16384) rather than the unlimited threshold used in + correctness tests. 
+ """ + configs = [ + {"threshold": 8192, "label": "threshold=8192"}, + {"threshold": 16384, "label": "threshold=16384"}, + ] + batch_size = 256 + input_len = 131072 # 128k + n_runs = 1000 + + # Pre-create tensors once (excluded from timing) + gpu_data = _to_gpu( + _make_ngram_test_data( + batch_size=batch_size, + input_len=input_len, + max_model_len=input_len + 64, + seed=77, + ) + ) + cpu_data = _make_ngram_test_data( + batch_size=batch_size, + input_len=input_len, + max_model_len=input_len + 64, + seed=77, + ) + + print(f"\n{'='*72}") + print(f"EXTREME BENCHMARK (batch={batch_size}, seq_len={input_len}, {n_runs} runs)") + print(f"{'─'*72}") + + for cfg in configs: + threshold = cfg["threshold"] + old_env = os.environ.get("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD") + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = str(threshold) + try: + # Warmup + for _ in range(3): + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + + # GPU kernel timing + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(n_runs): + self.ngram_match( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["token_ids_all"], + gpu_data["prompt_lens"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_encoder"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 10, + ) + paddle.device.synchronize() + t1 = time.perf_counter() + gpu_ms = (t1 - t0) / n_runs * 1000 + finally: + if old_env is None: + os.environ.pop("INFER_WITH_REFERENCE_TOKENUM_THRESHOLD", None) + else: + os.environ["INFER_WITH_REFERENCE_TOKENUM_THRESHOLD"] = old_env + + 
# CPU path: simulate copy-to-CPU-and-back overhead at extreme scale + cpu_runs = 50 # fewer runs — CPU copy of 256x128k is slow + paddle.device.synchronize() + t0 = time.perf_counter() + for _ in range(cpu_runs): + cpu_tensors = {k: paddle.to_tensor(v, place=paddle.CPUPlace()) for k, v in cpu_data.items()} + _ = cpu_tensors["draft_tokens"].cuda() + _ = cpu_tensors["seq_lens_this_time"].cuda() + paddle.device.synchronize() + t1 = time.perf_counter() + cpu_ms = (t1 - t0) / cpu_runs * 1000 + + speedup = cpu_ms / gpu_ms if gpu_ms > 0 else float("inf") + print(f" [{cfg['label']}]") + print(f" GPU kernel: {gpu_ms:.3f} ms/call ({gpu_ms * 1000:.1f} us)") + print(f" CPU path: {cpu_ms:.3f} ms/call") + print(f" Speedup: {speedup:.1f}x") + print() + + print(f"{'='*72}") + + +class TestHybridMtpNgramKernel(unittest.TestCase): + """Test hybrid_mtp_ngram GPU kernel correctness against CPU reference.""" + + @classmethod + def setUpClass(cls): + if not paddle.is_compiled_with_cuda(): + raise unittest.SkipTest("CUDA not available") + paddle.set_device("gpu") + try: + from fastdeploy.model_executor.ops.gpu import hybrid_mtp_ngram + + cls.hybrid_mtp_ngram = staticmethod(hybrid_mtp_ngram) + except Exception as e: + raise unittest.SkipTest(f"Cannot import hybrid_mtp_ngram op: {e}") + + def test_correctness_basic(self): + """Basic correctness: GPU output matches CPU reference.""" + data = _make_mixed_test_data(batch_size=4, seed=42) + max_ngram_size = 3 + min_ngram_size = 1 + max_draft_tokens = 10 + + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_hybrid_mtp_ngram( + data["input_ids"], + data["input_ids_len"], + data["pre_ids"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_decoder"], + data["max_dec_len"], + max_ngram_size, + min_ngram_size, + max_draft_tokens, + ) + + gpu_data = _to_gpu(data) + self.hybrid_mtp_ngram( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["pre_ids"], + 
gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + max_ngram_size, + min_ngram_size, + max_draft_tokens, + ) + paddle.device.synchronize() + + np.testing.assert_array_equal( + gpu_data["seq_lens_this_time"].numpy(), cpu_slt, err_msg="seq_lens_this_time mismatch" + ) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft, err_msg="draft_tokens mismatch") + + def test_correctness_varied_seeds(self): + """Test across multiple random seeds.""" + for seed in [0, 7, 123, 999]: + with self.subTest(seed=seed): + data = _make_mixed_test_data(batch_size=8, seed=seed) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_hybrid_mtp_ngram( + data["input_ids"], + data["input_ids_len"], + data["pre_ids"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 1, + 10, + ) + gpu_data = _to_gpu(data) + self.hybrid_mtp_ngram( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["pre_ids"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 1, + 10, + ) + paddle.device.synchronize() + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_large_batch_long_seq(self): + """bsz=256, seq_len=128k — scale the reviewer demanded. + + Uses high threshold to ensure all batches exercise the parallel search + path (default threshold=1024 would skip many batches at bsz=256). 
+ """ + high_threshold = 100000 + data = _make_mixed_test_data(batch_size=256, input_len=131072, pre_ids_len=131072 + 64, seed=77) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_hybrid_mtp_ngram( + data["input_ids"], + data["input_ids_len"], + data["pre_ids"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 1, + 10, + threshold=high_threshold, + ) + gpu_data = _to_gpu(data) + old_env = os.environ.get("SPEC_TOKENUM_THRESHOLD") + os.environ["SPEC_TOKENUM_THRESHOLD"] = str(high_threshold) + try: + self.hybrid_mtp_ngram( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["pre_ids"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 1, + 10, + ) + paddle.device.synchronize() + finally: + if old_env is None: + os.environ.pop("SPEC_TOKENUM_THRESHOLD", None) + else: + os.environ["SPEC_TOKENUM_THRESHOLD"] = old_env + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_single_batch_long_seq(self): + """bsz=1, seq_len=128k — single long sequence.""" + data = _make_mixed_test_data(batch_size=1, input_len=131072, pre_ids_len=131072 + 64, seed=88) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_hybrid_mtp_ngram( + data["input_ids"], + data["input_ids_len"], + data["pre_ids"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 1, + 10, + ) + gpu_data = _to_gpu(data) + self.hybrid_mtp_ngram( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["pre_ids"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + 
gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 1, + 10, + ) + paddle.device.synchronize() + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + def test_many_short_seqs(self): + """bsz=256, seq_len=1k — many short sequences.""" + high_threshold = 100000 + data = _make_mixed_test_data(batch_size=256, input_len=1024, seed=55) + cpu_draft = data["draft_tokens"].copy() + cpu_slt = data["seq_lens_this_time"].copy() + _cpu_hybrid_mtp_ngram( + data["input_ids"], + data["input_ids_len"], + data["pre_ids"], + data["step_idx"], + data["draft_token_num"], + cpu_draft, + cpu_slt, + data["seq_lens_decoder"], + data["max_dec_len"], + 3, + 1, + 10, + threshold=high_threshold, + ) + gpu_data = _to_gpu(data) + old_env = os.environ.get("SPEC_TOKENUM_THRESHOLD") + os.environ["SPEC_TOKENUM_THRESHOLD"] = str(high_threshold) + try: + self.hybrid_mtp_ngram( + gpu_data["input_ids"], + gpu_data["input_ids_len"], + gpu_data["pre_ids"], + gpu_data["step_idx"], + gpu_data["draft_token_num"], + gpu_data["draft_tokens"], + gpu_data["seq_lens_this_time"], + gpu_data["seq_lens_decoder"], + gpu_data["max_dec_len"], + 3, + 1, + 10, + ) + paddle.device.synchronize() + finally: + if old_env is None: + os.environ.pop("SPEC_TOKENUM_THRESHOLD", None) + else: + os.environ["SPEC_TOKENUM_THRESHOLD"] = old_env + np.testing.assert_array_equal(gpu_data["seq_lens_this_time"].numpy(), cpu_slt) + np.testing.assert_array_equal(gpu_data["draft_tokens"].numpy(), cpu_draft) + + +if __name__ == "__main__": + unittest.main(verbosity=2)