Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
int64_t* next_tokens, // [bs, tokens_per_step]
const int* max_think_lens, // [bs]
int* max_reply_lens, // [bs]
int64_t* step_idx, // [bs]
const int64_t* step_idx, // [bs]
const int64_t* eos_token_ids, // [eos_len]
int* limit_status, // [bs]
int* accept_num, // [bs]
Expand Down Expand Up @@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel(
int new_accept_num = original_accept_num;

// 本 step 的 token offset 对应的绝对 step
const int64_t current_base_step = step_idx[bid] - original_accept_num + 1;
const int64_t current_base_step = step_idx[bid] + 1;

for (int token_offset = 0; token_offset < original_accept_num;
token_offset++) {
Expand Down Expand Up @@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel(
// inject_token_ids[0])
if (status == 0 &&
(current_step - 1) ==
max_think_len) { // current_step - 1 是因为 speculate_verify 里
// step_idx + 1 了
max_think_len) { // current_step - 1 : 已输出 current_step-1
// 个thinking token
status = (inject_len > 0) ? 1 : done_status;
}
} else if (max_think_len == 0) {
Expand Down Expand Up @@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel(
}
}

// 更新 step_idx / accept_num(被截断的 token 需要回退
// step_idx)
const int discarded_tokens = original_accept_num - new_accept_num;
if (discarded_tokens > 0) {
step_idx[bid] -= discarded_tokens;
}

accept_num[bid] = new_accept_num;
limit_status[bid] = status;
max_reply_lens[bid] = max_reply_len;
Expand Down Expand Up @@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength(
const_cast<int64_t*>(next_tokens.data<int64_t>()),
max_think_lens.data<int>(),
const_cast<int*>(max_reply_lens.data<int>()),
const_cast<int64_t*>(step_idx.data<int64_t>()),
step_idx.data<int64_t>(),
eos_token_ids.data<int64_t>(),
const_cast<int*>(limit_status.data<int>()),
const_cast<int*>(accept_num.data<int>()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,54 +51,74 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
const int64_t step_idx_now = step_idx[bid];
const int64_t min_token_limit = min_tokens[bid];

const bool can_stop = (step_idx_now >= min_token_limit);
const bool can_stop = (step_idx_now + accept_num >= min_token_limit);
if (!can_stop) return;
if (!stop_flags[bid]) {
int accept_idx = 0;
bool is_end = false;
// 遍历起始位置
for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) {

// 首先检查 pre_ids 最后一个位置是否是 stop_seq 的最后一个 token
// 处理上一轮 accept_tokens 最后一个位置的 stop_seq 的情况
if (step_idx_now >= stop_seq_len) {
bool pre_ids_end = true;
for (int i = stop_seq_len - 1; i >= 0; --i) {
int pre_ids_idx = step_idx_now - (stop_seq_len - i);
if (pre_ids_idx < 0 || pre_ids_now[pre_ids_idx] != stop_seq_now[i]) {
pre_ids_end = false;
break;
}
}
if (pre_ids_end) {
// stop_seq 在上一轮已经完整出现在 pre_ids 末尾
// 本轮不需要接受任何 token
accept_nums[bid] = 0;
return;
}
}

// 遍历窗口结束位置,只检查到 accept_num - 2(倒数第二个位置)
// 这样如果 stop_seq 最后一个 token 在 accept_tokens 最后一位,下一轮处理
// 避免 accept_tokens[accept_idx - 1] = eos 时越界
for (; accept_idx <= accept_num - 2 && !is_end; accept_idx++) {
if (step_idx_now + accept_idx + 1 < stop_seq_len) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf("num %d < stop_seq_len %d\n",
step_idx_now - accept_num + accept_idx + 1,
step_idx_now + accept_idx + 1,
stop_seq_len);
#endif
continue;
}
// 遍历一个 stop_seqs
// 遍历一个 stop_seqs,从最后一个 token 开始比较
for (int i = stop_seq_len - 1; i >= 0; --i) {
int64_t cur_token_idx = -1;

// 通过当前值判断 token 是在 pre_ids 还是 accept_token 里
if (stop_seq_len - 1 - i < accept_idx) {
int offset = stop_seq_len - 1 - i;
int accept_tokens_idx = accept_idx - offset;

if (accept_tokens_idx >= 0) {
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"AcceptTokens bid:%d. tid:%d, accept_idx:%d, "
"accept_token_idx: "
"%d\n",
"accept_token_idx: %d\n",
bid,
tid,
accept_idx,
accept_idx - (stop_seq_len - 1 - i) - 1);
accept_tokens_idx);
#endif
cur_token_idx =
accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1];
cur_token_idx = accept_tokens_now[accept_tokens_idx];
} else {
// 需要从 pre_ids 读取
int pre_ids_idx = step_idx_now + accept_tokens_idx;
#ifdef DEBUG_SPEC_STOP_SEQS
printf(
"PreIds bid:%d. tid:%d, step_idx_now:%ld. "
"accept_idx:%d. "
"pre_id_idx: %ld\n",
"accept_idx:%d. pre_id_idx: %d\n",
bid,
tid,
step_idx_now,
accept_idx,
step_idx_now - accept_num + accept_idx -
(stop_seq_len - 1 - i));
pre_ids_idx);
#endif
int pre_ids_idx =
step_idx_now + accept_idx - (stop_seq_len - 1 - i);
// EC3
// 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23,
// 导致异常结束
Expand Down Expand Up @@ -128,10 +148,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags,
#ifdef DEBUG_SPEC_STOP_SEQS
printf("bid:%d end with accept_idx %d", bid, accept_idx);
#endif

// 循环退出后 accept_idx 已经递增,指向 stop_seq 最后一个 token
// 的下一个位置 accept_idx - 1 是 stop_seq 最后一个 token 的位置 将
// stop_seq 最后一个 token 替换为 eos,与非 MTP 行为对齐
accept_nums[bid] = accept_idx;
accept_tokens_now[accept_idx - 1] = end_ids[0];
// stop_flags[bid] = true;
}
}
}
Expand Down
Loading
Loading