Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
daf20d9
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 6, 2026
6f1e63c
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 6, 2026
4deb7a7
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 9, 2026
676daf6
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 9, 2026
9bcfdca
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 10, 2026
2bfa878
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 10, 2026
262c470
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 11, 2026
171b4d3
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 17, 2026
def0bd2
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 19, 2026
4fad5dc
Merge remote-tracking branch 'upstream/develop' into develop
cloudforge1 Mar 20, 2026
3d739a6
Port ngram_match and hybrid_mtp_ngram kernels to CUDA
cloudforge1 Mar 20, 2026
477f749
Add correctness + latency test for GPU ngram kernels
cloudforge1 Mar 20, 2026
c349b12
Fix test data: step_idx semantics and ngram-matchable patterns
cloudforge1 Mar 20, 2026
217e587
fix: add CPU fallback path for ngram_match and hybrid_mtp_ngram ops
cloudforge1 Mar 21, 2026
08fe00a
fix(test): wrap imported ops with staticmethod to prevent self-binding
cloudforge1 Mar 21, 2026
305868d
fix(test): ensure max_model_len >= input_len to prevent broadcast err…
cloudforge1 Mar 21, 2026
1dfaed5
fix: keep input_ids_len on CPU in __init__, move to GPU in _run_impl
cloudforge1 Mar 22, 2026
b7f1f38
Extract shared ngram search into __device__ helper (ngram_match_commo…
cloudforge1 Mar 25, 2026
3f71877
refactor: parallel CUDA kernels for ngram_match (<<<bsz,256>>> search)
cloudforge1 Mar 30, 2026
838d6dc
fix: move __global__ kernel defs from .cuh to .cu files (fix linker m…
cloudforge1 Mar 30, 2026
f45e39b
fix: align mixed kernel signatures with host function tensors
cloudforge1 Mar 30, 2026
a7f149a
fix: address review — GPU mirror for input_ids_len, device mismatch i…
cloudforge1 Apr 2, 2026
65f609b
bench: add 5-group benchmark matching NKNaN methodology
cloudforge1 Apr 2, 2026
c6e698f
fix: rename benchmark for CI discovery, bump to 10k iterations
cloudforge1 Apr 2, 2026
d5e2e9a
Gate expensive ngram tests behind env vars
cloudforge1 Apr 2, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
307 changes: 252 additions & 55 deletions custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -12,44 +12,193 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include <iostream>
#include <vector>
#include <string>
#include <algorithm>
#include <chrono>
#include <cstdlib>
#include <cstring>
#include <string>
#include "paddle/extension.h"
#include "../ngram_match_common.cuh"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

int sum_mixed(const int *value, int num) {
// ============================================================
// Phase 1 mixed search kernel — one block per batch item
// (batch_idx = blockIdx.x).
//
// For each batch item, tries ngram sizes from max_ngram_size down to
// min_ngram_size (longest match preferred). The ngram is the last
// `ngram_size` generated tokens taken from pre_ids at step_idx. Each
// candidate is searched first in the prompt (input_ids, haystack_type
// 0) and then in the previously generated tokens (pre_ids,
// haystack_type 1); the first hit wins and the block exits.
//
// Output, one record per batch item in match_results[batch_idx]:
//   match_pos     — start of the matched window, -1 if no match,
//   ngram_size    — ngram length that matched (0 if none),
//   haystack_type — 0 = matched in input_ids, 1 = matched in pre_ids.
//
// NOTE(review): parallel_ngram_search is defined in
// ngram_match_common.cuh. From its use here it appears to return
// INT64_MAX on a miss and the matched position otherwise, using
// s_min_pos as block-shared scratch; it presumably returns the same
// value to every thread of the block (the block-wide `return` below
// relies on that) — confirm against the helper's definition.
// ============================================================
__global__ void ngram_match_mixed_search_kernel(
    const int64_t *input_ids,           // [batch, input_ids_stride] prompt tokens
    const int64_t *input_ids_len,       // per-batch valid length of input_ids
    const int64_t *pre_ids,             // [batch, pre_ids_stride] generated tokens
    const int64_t *step_idx,            // per-batch current decode step
    const int32_t *seq_lens_this_time,  // per-batch active token count (0 = idle)
    int64_t input_ids_stride,
    int64_t pre_ids_stride,
    int64_t max_batch_size,
    int max_ngram_size,
    int min_ngram_size,
    NgramMatchResult *match_results) {  // [batch] output records (see above)
  int batch_idx = blockIdx.x;
  if (batch_idx >= max_batch_size) return;

  // Block-shared scratch handed to parallel_ngram_search (apparently
  // used to reduce the minimum matching position across threads).
  __shared__ int64_t s_min_pos;

  // Thread 0 writes the "no match" sentinel; the barrier below makes it
  // visible to all threads before any early return.
  if (threadIdx.x == 0) {
    match_results[batch_idx].match_pos = -1;
    match_results[batch_idx].ngram_size = 0;
    match_results[batch_idx].haystack_type = 0;
  }
  __syncthreads();

  // Skip batch items with no active tokens (matches CPU path logic).
  // The condition is uniform across the block, so returning here is safe.
  if (seq_lens_this_time[batch_idx] == 0) return;

  const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
  const int64_t cur_input_ids_len = input_ids_len[batch_idx];
  const int64_t *cur_pre_ids = pre_ids + batch_idx * pre_ids_stride;
  const int64_t cur_step_idx = step_idx[batch_idx];

  // Try longer ngrams first: a longer context match is preferred.
  for (int ngram_size = max_ngram_size; ngram_size >= min_ngram_size;
       --ngram_size) {
    // Not enough generated tokens yet to form this ngram. Uniform per
    // block, so no divergent path around the helper's internal barriers.
    if (cur_step_idx < ngram_size) continue;

    // The ngram is the last `ngram_size` tokens of pre_ids.
    const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

    // 1) Search the prompt tokens.
    int64_t pos = parallel_ngram_search(
        cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
    if (pos != INT64_MAX) {
      if (threadIdx.x == 0) {
        match_results[batch_idx].match_pos = pos;
        match_results[batch_idx].ngram_size = ngram_size;
        match_results[batch_idx].haystack_type = 0;
      }
      return;  // whole block exits together on a hit
    }

    // 2) Fall back to searching the previously generated tokens.
    pos = parallel_ngram_search(
        cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos);
    if (pos != INT64_MAX) {
      if (threadIdx.x == 0) {
        match_results[batch_idx].match_pos = pos;
        match_results[batch_idx].ngram_size = ngram_size;
        match_results[batch_idx].haystack_type = 1;
      }
      return;
    }
  }
}

// ============================================================
// Phase 2 mixed gather kernel — serial threshold + copy.
//
// Intentionally launched <<<1, 1>>>: the per-batch draft budget depends
// on a running sum of seq_lens_this_time over earlier batch items, so
// items must be processed in order by a single thread.
//
// For each batch item with a Phase-1 match, appends up to
// max_draft_tokens_param draft tokens (clamped by max_dec_len and the
// global budget `threshold`) from the matched haystack into
// draft_tokens, and bumps seq_lens_this_time accordingly.
//
// NOTE(review): draft_token_num is accepted but never read here; it is
// kept for signature parity with the host launch and the CPU path.
// ============================================================
__global__ void ngram_match_mixed_gather_kernel(
    const int64_t *input_ids,          // [batch, input_ids_stride] prompt tokens
    const int64_t *input_ids_len,      // per-batch valid prompt length
    const int64_t *pre_ids,            // [batch, pre_ids_stride] generated tokens
    const int64_t *step_idx,           // per-batch current decode step
    const int *draft_token_num,        // unused — see NOTE(review) above
    int64_t *draft_tokens,             // [batch, draft_tokens_stride] out drafts
    int32_t *seq_lens_this_time,       // in/out per-batch active token counts
    const int32_t *seq_lens_decoder,   // >0 marks an item still decoding
    const int64_t *max_dec_len,        // per-batch decode-length cap
    int64_t input_ids_stride,
    int64_t pre_ids_stride,
    int64_t draft_tokens_stride,
    int64_t max_batch_size,
    int max_draft_tokens_param,        // per-step draft budget before clamping
    int threshold,                     // global token budget (SPEC_TOKENUM_THRESHOLD)
    const NgramMatchResult *match_results) {  // Phase-1 results, one per batch
  // Defensive guard: the logic below is strictly serial and mutates
  // seq_lens_this_time / draft_tokens. If the kernel is ever launched
  // with more than <<<1, 1>>>, extra threads must not participate or
  // they would race on those buffers.
  if (blockIdx.x != 0 || threadIdx.x != 0) return;

  // Count items still decoding; each reserves at least one token of the
  // global budget for the items after the current one.
  int unprocessed_batch_size = 0;
  for (int i = 0; i < max_batch_size; i++) {
    if (seq_lens_decoder[i] > 0) {
      unprocessed_batch_size++;
    }
  }

  for (int batch_idx = 0; batch_idx < max_batch_size; batch_idx++) {
    const int ori_seq_len_this_time = seq_lens_this_time[batch_idx];
    // Clamp the draft budget by the remaining decode length of this item.
    int max_draft_tokens =
        static_cast<int>(min(static_cast<int64_t>(max_draft_tokens_param -
                                                  ori_seq_len_this_time + 1),
                             max_dec_len[batch_idx] - step_idx[batch_idx] - 1));

    if (ori_seq_len_this_time == 0 || max_draft_tokens <= 0) {
      continue;
    }

    unprocessed_batch_size--;
    // Running total of tokens scheduled so far, including this item's
    // current count (earlier entries may already have been bumped).
    int sum_token_num = 0;
    for (int i = 0; i <= batch_idx; i++) {
      sum_token_num += seq_lens_this_time[i];
    }
    int left_min_token_num = unprocessed_batch_size;

    // Shrink the draft budget so the global threshold is not exceeded
    // once every remaining item takes at least one token.
    if (sum_token_num + max_draft_tokens + left_min_token_num > threshold) {
      int tmp = threshold - sum_token_num - left_min_token_num;
      max_draft_tokens = min(tmp, max_draft_tokens);
    }

    // No budget left for drafting at all — skip this item.
    if (sum_token_num + left_min_token_num >= threshold - 1) {
      continue;
    }

    const NgramMatchResult &res = match_results[batch_idx];
    if (res.match_pos < 0) continue;  // Phase 1 found no match

    // Select the haystack the match was found in (0 = prompt, 1 = pre_ids).
    const int64_t *haystack;
    int64_t haystack_len;
    if (res.haystack_type == 0) {
      haystack = input_ids + batch_idx * input_ids_stride;
      haystack_len = input_ids_len[batch_idx];
    } else {
      haystack = pre_ids + batch_idx * pre_ids_stride;
      haystack_len = step_idx[batch_idx];
    }

    // Tokens following the matched window become the draft continuation.
    int64_t start_idx = res.match_pos + res.ngram_size;
    int64_t end_idx =
        min(start_idx + static_cast<int64_t>(max_draft_tokens), haystack_len);
    if (start_idx >= end_idx) continue;

    int64_t n = end_idx - start_idx;
    seq_lens_this_time[batch_idx] =
        static_cast<int32_t>(ori_seq_len_this_time + n);
    int64_t *cur_draft = draft_tokens + batch_idx * draft_tokens_stride;
    for (int64_t k = 0; k < n; k++) {
      cur_draft[ori_seq_len_this_time + k] = haystack[start_idx + k];
    }
  }
}

// ============================================================
// CPU path helper — preserved for CPU-only callers and tests.
//
// Returns the inclusive prefix sum value[0] + ... + value[num].
// Note the inclusive upper bound: callers pass the current batch index
// and expect its own entry to be counted (num + 1 elements in total).
// ============================================================
static int sum_mixed_cpu(const int *value, int num) {
  int total = 0;
  int idx = 0;
  while (idx <= num) {
    total += value[idx];
    ++idx;
  }
  return total;
}

void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int min_ngram_size = 1,
const int max_draft_tokens = 10) {
static void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t *input_ids_len,
const int64_t *pre_ids,
const int64_t *step_idx,
const int *draft_token_num,
int64_t *draft_tokens,
int32_t *seq_lens_this_time,
int32_t *seq_lens_decoder,
int64_t *max_dec_len,
int64_t input_ids_stride,
int64_t pre_ids_stride,
int64_t draft_tokens_stride,
int64_t max_batch_size,
int max_ngram_size = 3,
int min_ngram_size = 1,
const int max_draft_tokens = 10) {
int threshold = 1024;
// dynamic in future
char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
Expand Down Expand Up @@ -77,7 +226,7 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
unprocessed_batch_size--;

auto sum_token_num = sum_mixed(seq_lens_this_time, batch_idx);
auto sum_token_num = sum_mixed_cpu(seq_lens_this_time, batch_idx);
int left_min_token_num = unprocessed_batch_size;

if (sum_token_num + max_draft_tokens_query + left_min_token_num >
Expand All @@ -91,21 +240,16 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
continue;
}
bool match_global = false;
// apply ngram_match in input_ids
for (int ngram_size = max_ngram_size;
ngram_size >= min_ngram_size && !match_global;
--ngram_size) {
// Extract the last n tokens as our search ngram
if (cur_step_idx < ngram_size) {
continue;
}
const int64_t *ngram = cur_pre_ids + (cur_step_idx + 1 - ngram_size);

// Iterate through sliding windows of size ngram_size
// bool match_input = false;
for (int64_t i = 0; i <= cur_input_ids_len - ngram_size && !match_global;
++i) {
// Check if the current window matches the ngram
bool match_local = true;
for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_input_ids[i + j]) {
Expand All @@ -120,24 +264,19 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
if (start_idx >= end_idx) continue;

int64_t cur_draft_token_num = end_idx - start_idx;

seq_lens_this_time[batch_idx] =
ori_seq_len_this_time + cur_draft_token_num;
memcpy(cur_draft_tokens + ori_seq_len_this_time,
cur_input_ids + start_idx,
sizeof(int64_t) * cur_draft_token_num);
// To break the current batch_idx for-loop
match_global = true;
break;
}
}
// apply ngram_match in generated tokens
if (!match_global) {
for (int64_t i = 0; i <= cur_step_idx - ngram_size && !match_global;
++i) {
// Check if the current window matches the ngram
bool match_local = true;

for (int j = 0; j < ngram_size; j++) {
if (ngram[j] != cur_pre_ids[i + j]) {
match_local = false;
Expand All @@ -148,13 +287,8 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
int64_t start_idx = i + ngram_size;
int64_t end_idx =
std::min(start_idx + max_draft_tokens_query, cur_step_idx);

int64_t cur_draft_token_num = end_idx - start_idx;

if (start_idx >= end_idx) continue;
// printf("match in Output with Ngram_size %d.
// %lld:[%lld,%lld]\n",ngram_size, cur_draft_token_num, start_idx,
// end_idx);

seq_lens_this_time[batch_idx] =
ori_seq_len_this_time + cur_draft_token_num;
Expand All @@ -170,6 +304,15 @@ void find_candidate_pred_tokens_mixed(const int64_t *input_ids,
}
}

// ============================================================
// GPU path — Two-phase parallel CUDA kernels for hybrid ngram matching.
//
// Phase 1: <<<bsz, NGRAM_BLOCK_THREADS>>> — parallel sliding-window
// search within each batch item (256 threads per batch).
// Phase 2: <<<1, 1>>> — serial threshold + token copy (inter-batch
// dependency via running sum of seq_lens_this_time).
// ============================================================

void HybridMtpNgram(const paddle::Tensor &input_ids,
const paddle::Tensor &input_ids_len,
const paddle::Tensor &pre_ids,
Expand All @@ -193,23 +336,77 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,

const int64_t max_batch_size = seq_lens_this_time.shape()[0];

find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens);
int threshold = 1024;
const char *env_var = getenv("SPEC_TOKENUM_THRESHOLD");
if (env_var) {
threshold = std::stoi(env_var);
Comment on lines +341 to +342
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里同样使用 std::stoi(env_var) 解析 SPEC_TOKENUM_THRESHOLD;一旦环境变量不是纯数字,会抛异常并可能导致进程异常退出。建议使用不抛异常的解析方式(如 std::from_chars/strtol)并在失败时回退默认值,同时做阈值范围校验。

Suggested change
if (env_var) {
threshold = std::stoi(env_var);
if (env_var && env_var[0] != '\0') {
char *end_ptr = nullptr;
long parsed_threshold = strtol(env_var, &end_ptr, 10);
if (end_ptr != env_var && *end_ptr == '\0' && parsed_threshold > 0 &&
parsed_threshold <= static_cast<long>(std::numeric_limits<int>::max())) {
threshold = static_cast<int>(parsed_threshold);
}

Copilot uses AI. Check for mistakes.
}

if (input_ids.is_gpu()) {
auto stream = input_ids.stream();

// Allocate scratch buffer for Phase 1 → Phase 2 communication
auto match_buf = paddle::empty(
{max_batch_size * static_cast<int64_t>(sizeof(NgramMatchResult))},
paddle::DataType::UINT8,
input_ids.place());
auto *match_results =
reinterpret_cast<NgramMatchResult *>(match_buf.data<uint8_t>());

// Phase 1: parallel search — one block per batch, 256 threads per block
ngram_match_mixed_search_kernel<<<max_batch_size,
NGRAM_BLOCK_THREADS,
0,
stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
seq_lens_this_time.data<int32_t>(),
input_ids_stride,
pre_ids_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
match_results);

// Phase 2: serial threshold + token copy (same stream = ordered)
ngram_match_mixed_gather_kernel<<<1, 1, 0, stream>>>(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
seq_lens_decoder.data<int32_t>(),
max_dec_len.data<int64_t>(),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_draft_tokens,
threshold,
match_results);
} else {
find_candidate_pred_tokens_mixed(
input_ids.data<int64_t>(),
input_ids_len.data<int64_t>(),
pre_ids.data<int64_t>(),
step_idx.data<int64_t>(),
draft_token_num.data<int>(),
const_cast<int64_t *>(draft_tokens.data<int64_t>()),
const_cast<int32_t *>(seq_lens_this_time.data<int32_t>()),
const_cast<int32_t *>(seq_lens_decoder.data<int32_t>()),
const_cast<int64_t *>(max_dec_len.data<int64_t>()),
input_ids_stride,
pre_ids_stride,
draft_tokens_stride,
max_batch_size,
max_ngram_size,
min_ngram_size,
max_draft_tokens);
}
}

PD_BUILD_STATIC_OP(hybrid_mtp_ngram)
Expand Down
Loading
Loading