Skip to content

Commit 00a6d4c

Browse files
author
cloudforge1
committed
fix: address Copilot review — conditional return, defensive guards, GPU placement
- ngram_match.cu: add a remaining<=0 early return; return early only when tokens were actually produced (matching the CPU path's `continue` behavior); include encoder-active items in the Phase 2 threshold-budget scan
- ngram_match_mixed.cu: split the max_draft_tokens computation into explicit steps to prevent negative intermediates; return early only when tokens were actually produced; add a comment documenting the seq_lens_decoder invariant
- ngram.py: call .cuda() explicitly when creating input_ids_len_gpu
- test_ngram_gpu_kernel.py: use CPUPlace() in the latency benchmark so it measures the actual D2H/H2D roundtrip
1 parent 04346f8 commit 00a6d4c

File tree

4 files changed

+83
-62
lines changed

4 files changed

+83
-62
lines changed

custom_ops/gpu_ops/speculate_decoding/draft_model/ngram_match_mixed.cu

Lines changed: 42 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,13 @@ __global__ void ngram_match_mixed_search_kernel(
6161
// Skip batch items with no active tokens
6262
if (ori_seq_len_this_time == 0) return;
6363

64-
// Compute max_draft_tokens for this batch item
65-
int max_draft_tokens = static_cast<int>(min(
66-
static_cast<int64_t>(max_draft_tokens_param - ori_seq_len_this_time + 1),
67-
max_dec_len[batch_idx] - step_idx[batch_idx] - 1));
68-
if (max_draft_tokens <= 0) return;
64+
// Compute max_draft_tokens for this batch item.
65+
// Split into explicit steps to avoid negative intermediate values.
66+
int64_t draft_budget =
67+
static_cast<int64_t>(max_draft_tokens_param) - ori_seq_len_this_time + 1;
68+
int64_t remaining_dec = max_dec_len[batch_idx] - step_idx[batch_idx] - 1;
69+
if (draft_budget <= 0 || remaining_dec <= 0) return;
70+
int max_draft_tokens = static_cast<int>(min(draft_budget, remaining_dec));
6971

7072
const int64_t *cur_input_ids = input_ids + batch_idx * input_ids_stride;
7173
const int64_t cur_input_ids_len = input_ids_len[batch_idx];
@@ -81,44 +83,45 @@ __global__ void ngram_match_mixed_search_kernel(
8183
int64_t pos = parallel_ngram_search(
8284
cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
8385
if (pos != INT64_MAX) {
84-
if (threadIdx.x == 0) {
86+
int64_t start_idx = pos + ngram_size;
87+
int64_t end_idx = min(start_idx + static_cast<int64_t>(max_draft_tokens),
88+
cur_input_ids_len);
89+
if (threadIdx.x == 0 && start_idx < end_idx) {
8590
// Tentative token copy to scratch
86-
int64_t start_idx = pos + ngram_size;
87-
int64_t end_idx =
88-
min(start_idx + static_cast<int64_t>(max_draft_tokens),
89-
cur_input_ids_len);
90-
if (start_idx < end_idx) {
91-
int64_t n = end_idx - start_idx;
92-
seq_lens_this_time_copy[batch_idx] =
93-
static_cast<int32_t>(ori_seq_len_this_time + n);
94-
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
95-
for (int64_t k = 0; k < n; k++) {
96-
dst[ori_seq_len_this_time + k] = cur_input_ids[start_idx + k];
97-
}
91+
int64_t n = end_idx - start_idx;
92+
seq_lens_this_time_copy[batch_idx] =
93+
static_cast<int32_t>(ori_seq_len_this_time + n);
94+
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
95+
for (int64_t k = 0; k < n; k++) {
96+
dst[ori_seq_len_this_time + k] = cur_input_ids[start_idx + k];
9897
}
9998
}
100-
return;
99+
// Only early-exit when tokens were actually produced
100+
if (start_idx < end_idx) {
101+
return;
102+
}
101103
}
102104

103105
pos = parallel_ngram_search(
104106
cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos);
105107
if (pos != INT64_MAX) {
106-
if (threadIdx.x == 0) {
108+
int64_t start_idx = pos + ngram_size;
109+
int64_t end_idx =
110+
min(start_idx + static_cast<int64_t>(max_draft_tokens), cur_step_idx);
111+
if (threadIdx.x == 0 && start_idx < end_idx) {
107112
// Tentative token copy to scratch
108-
int64_t start_idx = pos + ngram_size;
109-
int64_t end_idx = min(
110-
start_idx + static_cast<int64_t>(max_draft_tokens), cur_step_idx);
111-
if (start_idx < end_idx) {
112-
int64_t n = end_idx - start_idx;
113-
seq_lens_this_time_copy[batch_idx] =
114-
static_cast<int32_t>(ori_seq_len_this_time + n);
115-
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
116-
for (int64_t k = 0; k < n; k++) {
117-
dst[ori_seq_len_this_time + k] = cur_pre_ids[start_idx + k];
118-
}
113+
int64_t n = end_idx - start_idx;
114+
seq_lens_this_time_copy[batch_idx] =
115+
static_cast<int32_t>(ori_seq_len_this_time + n);
116+
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
117+
for (int64_t k = 0; k < n; k++) {
118+
dst[ori_seq_len_this_time + k] = cur_pre_ids[start_idx + k];
119119
}
120120
}
121-
return;
121+
// Only early-exit when tokens were actually produced
122+
if (start_idx < end_idx) {
123+
return;
124+
}
122125
}
123126
}
124127
}
@@ -389,6 +392,13 @@ void HybridMtpNgram(const paddle::Tensor &input_ids,
389392
if (input_ids.is_gpu()) {
390393
auto stream = input_ids.stream();
391394

395+
// NOTE: GPU path does not pass seq_lens_decoder to kernels — the mixed
396+
// variant uses ori_seq_len_this_time == 0 to skip inactive items. This
397+
// matches CPU behavior under the invariant that seq_lens_decoder > 0 iff
398+
// ori_seq_len_this_time > 0 (holds during normal MTP decoding). The CPU
399+
// path counts seq_lens_decoder > 0 for threshold budget; the GPU scan
400+
// counts tentative > 0, which is equivalent under this invariant.
401+
392402
// Allocate scratch buffers for Phase 1 → Phase 2 communication
393403

394404
// Scratch copy of draft_tokens (Phase 1 writes tentative tokens here)

custom_ops/gpu_ops/speculate_decoding/ngram_match.cu

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
7272

7373
// Compute max_draft_tokens for this batch item
7474
int64_t remaining = max_dec_len[batch_idx] - cur_step_idx - 1;
75+
if (remaining <= 0) return;
7576
int max_draft_tokens = static_cast<int>(
7677
min(static_cast<int64_t>(draft_token_num[batch_idx]), remaining));
7778

@@ -83,42 +84,43 @@ __global__ void ngram_match_search_kernel(const int64_t *input_ids,
8384
int64_t pos = parallel_ngram_search(
8485
cur_input_ids, cur_input_ids_len, ngram, ngram_size, &s_min_pos);
8586
if (pos != INT64_MAX) {
86-
if (threadIdx.x == 0) {
87+
int64_t start_idx = pos + ngram_size;
88+
int64_t end_idx = min(start_idx + static_cast<int64_t>(max_draft_tokens),
89+
cur_input_ids_len);
90+
if (threadIdx.x == 0 && start_idx < end_idx) {
8791
// Tentative token copy to scratch
88-
int64_t start_idx = pos + ngram_size;
89-
int64_t end_idx =
90-
min(start_idx + static_cast<int64_t>(max_draft_tokens),
91-
cur_input_ids_len);
92-
if (start_idx < end_idx) {
93-
int64_t n = end_idx - start_idx;
94-
seq_lens_this_time_copy[batch_idx] = static_cast<int32_t>(1 + n);
95-
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
96-
for (int64_t k = 0; k < n; k++) {
97-
dst[1 + k] = cur_input_ids[start_idx + k];
98-
}
92+
int64_t n = end_idx - start_idx;
93+
seq_lens_this_time_copy[batch_idx] = static_cast<int32_t>(1 + n);
94+
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
95+
for (int64_t k = 0; k < n; k++) {
96+
dst[1 + k] = cur_input_ids[start_idx + k];
9997
}
10098
}
101-
return;
99+
// Only early-exit when tokens were actually produced
100+
if (start_idx < end_idx) {
101+
return;
102+
}
102103
}
103104

104105
pos = parallel_ngram_search(
105106
cur_pre_ids, cur_step_idx, ngram, ngram_size, &s_min_pos);
106107
if (pos != INT64_MAX) {
107-
if (threadIdx.x == 0) {
108+
int64_t start_idx = pos + ngram_size;
109+
int64_t end_idx =
110+
min(start_idx + static_cast<int64_t>(max_draft_tokens), cur_step_idx);
111+
if (threadIdx.x == 0 && start_idx < end_idx) {
108112
// Tentative token copy to scratch
109-
int64_t start_idx = pos + ngram_size;
110-
int64_t end_idx = min(
111-
start_idx + static_cast<int64_t>(max_draft_tokens), cur_step_idx);
112-
if (start_idx < end_idx) {
113-
int64_t n = end_idx - start_idx;
114-
seq_lens_this_time_copy[batch_idx] = static_cast<int32_t>(1 + n);
115-
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
116-
for (int64_t k = 0; k < n; k++) {
117-
dst[1 + k] = cur_pre_ids[start_idx + k];
118-
}
113+
int64_t n = end_idx - start_idx;
114+
seq_lens_this_time_copy[batch_idx] = static_cast<int32_t>(1 + n);
115+
int64_t *dst = draft_tokens_copy + batch_idx * draft_tokens_stride;
116+
for (int64_t k = 0; k < n; k++) {
117+
dst[1 + k] = cur_pre_ids[start_idx + k];
119118
}
120119
}
121-
return;
120+
// Only early-exit when tokens were actually produced
121+
if (start_idx < end_idx) {
122+
return;
123+
}
122124
}
123125
}
124126
}
@@ -147,12 +149,21 @@ __global__ void ngram_match_gather_kernel(
147149

148150
int tid = threadIdx.x;
149151

150-
// Load tentative values from Phase 1
152+
// Load tentative values from Phase 1.
153+
// Encoder-active items are included in the scan with their original
154+
// seq_lens_this_time to match CPU threshold-budget accounting.
151155
int tentative = 0;
152156
int is_active = 0;
153157
if (tid < max_batch_size) {
154-
tentative = seq_lens_this_time_copy[tid];
155-
is_active = (tentative > 0) ? 1 : 0;
158+
if (seq_lens_encoder[tid] > 0) {
159+
// Encoder-active: contribute original token count to threshold budget.
160+
// seq_lens_this_time[tid] is still unmodified at this point.
161+
tentative = seq_lens_this_time[tid];
162+
is_active = 1;
163+
} else {
164+
tentative = seq_lens_this_time_copy[tid];
165+
is_active = (tentative > 0) ? 1 : 0;
166+
}
156167
}
157168

158169
// Scan 1: inclusive prefix sum of tentative token counts

fastdeploy/spec_decode/ngram.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def __init__(self, fd_config: "FDConfig"):
3737
super().__init__(fd_config)
3838
self.max_ngram_size = self.speculative_config.max_ngram_size
3939
self.input_ids_len = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cpu()
40-
self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64")
40+
self.input_ids_len_gpu = paddle.zeros(shape=[self.max_num_seqs, 1], dtype="int64").cuda()
4141

4242
def update(self, bid: int, seq_len: int):
4343
"""

tests/spec_decode/test_ngram_gpu_kernel.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,7 +610,7 @@ def test_latency(self):
610610
t0 = time.perf_counter()
611611
for _ in range(n_runs):
612612
# Simulate old path: copy all tensors to CPU then back
613-
cpu_tensors = {k: paddle.to_tensor(v) for k, v in cpu_data.items()}
613+
cpu_tensors = {k: paddle.to_tensor(v, place=paddle.CPUPlace()) for k, v in cpu_data.items()}
614614
_ = cpu_tensors["draft_tokens"].cuda()
615615
_ = cpu_tensors["seq_lens_this_time"].cuda()
616616
paddle.device.synchronize()

0 commit comments

Comments
 (0)