diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu index 18aa5d53d21..e620e914a25 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu @@ -34,7 +34,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int64_t* next_tokens, // [bs, tokens_per_step] const int* max_think_lens, // [bs] int* max_reply_lens, // [bs] - int64_t* step_idx, // [bs] + const int64_t* step_idx, // [bs] const int64_t* eos_token_ids, // [eos_len] int* limit_status, // [bs] int* accept_num, // [bs] @@ -68,7 +68,7 @@ __global__ void speculate_limit_thinking_content_length_kernel( int new_accept_num = original_accept_num; // 本 step 的 token offset 对应的绝对 step - const int64_t current_base_step = step_idx[bid] - original_accept_num + 1; + const int64_t current_base_step = step_idx[bid] + 1; for (int token_offset = 0; token_offset < original_accept_num; token_offset++) { @@ -100,8 +100,8 @@ __global__ void speculate_limit_thinking_content_length_kernel( // inject_token_ids[0]) if (status == 0 && (current_step - 1) == - max_think_len) { // current_step - 1 是因为 speculate_verify 里 - // step_idx + 1 了 + max_think_len) { // current_step - 1 : 已输出 current_step-1 + // 个thinking token status = (inject_len > 0) ? 
1 : done_status; } } else if (max_think_len == 0) { @@ -181,13 +181,6 @@ __global__ void speculate_limit_thinking_content_length_kernel( } } - // 更新 step_idx / accept_num(被截断的 token 需要回退 - // step_idx) - const int discarded_tokens = original_accept_num - new_accept_num; - if (discarded_tokens > 0) { - step_idx[bid] -= discarded_tokens; - } - accept_num[bid] = new_accept_num; limit_status[bid] = status; max_reply_lens[bid] = max_reply_len; @@ -221,7 +214,7 @@ void SpeculateLimitThinkingContentLength( const_cast(next_tokens.data()), max_think_lens.data(), const_cast(max_reply_lens.data()), - const_cast(step_idx.data()), + step_idx.data(), eos_token_ids.data(), const_cast(limit_status.data()), const_cast(accept_num.data()), diff --git a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu index ee364884e96..b0728854ef6 100644 --- a/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu +++ b/custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu @@ -51,54 +51,74 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, const int64_t step_idx_now = step_idx[bid]; const int64_t min_token_limit = min_tokens[bid]; - const bool can_stop = (step_idx_now >= min_token_limit); + const bool can_stop = (step_idx_now + accept_num >= min_token_limit); if (!can_stop) return; if (!stop_flags[bid]) { int accept_idx = 0; bool is_end = false; - // 遍历起始位置 - for (; accept_idx <= accept_num - 1 && !is_end; accept_idx++) { + + // 首先检查 pre_ids 最后一个位置是否是 stop_seq 的最后一个 token + // 处理上一轮 accept_tokens 最后一个位置的 stop_seq 的情况 + if (step_idx_now >= stop_seq_len) { + bool pre_ids_end = true; + for (int i = stop_seq_len - 1; i >= 0; --i) { + int pre_ids_idx = step_idx_now - (stop_seq_len - i); + if (pre_ids_idx < 0 || pre_ids_now[pre_ids_idx] != stop_seq_now[i]) { + pre_ids_end = false; + break; + } + } + if (pre_ids_end) { + // stop_seq 在上一轮已经完整出现在 
pre_ids 末尾 + // 本轮不需要接受任何 token + accept_nums[bid] = 0; + return; + } + } + + // 遍历窗口结束位置,只检查到 accept_num - 2(倒数第二个位置) + // 这样如果 stop_seq 最后一个 token 在 accept_tokens 最后一位,下一轮处理 + // 避免 accept_tokens[accept_idx - 1] = eos 时越界 + for (; accept_idx <= accept_num - 2 && !is_end; accept_idx++) { if (step_idx_now + accept_idx + 1 < stop_seq_len) { #ifdef DEBUG_SPEC_STOP_SEQS printf("num %d < stop_seq_len %d\n", - step_idx_now - accept_num + accept_idx + 1, + step_idx_now + accept_idx + 1, stop_seq_len); #endif continue; } - // 遍历一个 stop_seqs + // 遍历一个 stop_seqs,从最后一个 token 开始比较 for (int i = stop_seq_len - 1; i >= 0; --i) { int64_t cur_token_idx = -1; - // 通过当前值判断 token 是在 pre_ids 还是 accept_token 里 - if (stop_seq_len - 1 - i < accept_idx) { + int offset = stop_seq_len - 1 - i; + int accept_tokens_idx = accept_idx - offset; + + if (accept_tokens_idx >= 0) { #ifdef DEBUG_SPEC_STOP_SEQS printf( "AcceptTokens bid:%d. tid:%d, accept_idx:%d, " - "accept_token_idx: " - "%d\n", + "accept_token_idx: %d\n", bid, tid, accept_idx, - accept_idx - (stop_seq_len - 1 - i) - 1); + accept_tokens_idx); #endif - cur_token_idx = - accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1]; + cur_token_idx = accept_tokens_now[accept_tokens_idx]; } else { + // 需要从 pre_ids 读取 + int pre_ids_idx = step_idx_now + accept_tokens_idx; #ifdef DEBUG_SPEC_STOP_SEQS printf( "PreIds bid:%d. tid:%d, step_idx_now:%ld. " - "accept_idx:%d. " - "pre_id_idx: %ld\n", + "accept_idx:%d. 
pre_id_idx: %d\n", bid, tid, step_idx_now, accept_idx, - step_idx_now - accept_num + accept_idx - - (stop_seq_len - 1 - i)); + pre_ids_idx); #endif - int pre_ids_idx = - step_idx_now + accept_idx - (stop_seq_len - 1 - i); // EC3 // 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23, // 导致异常结束 @@ -128,10 +148,11 @@ __global__ void spec_set_value_by_stop_seqs(bool *stop_flags, #ifdef DEBUG_SPEC_STOP_SEQS printf("bid:%d end with accept_idx %d", bid, accept_idx); #endif - + // 循环退出后 accept_idx 已经递增,指向 stop_seq 最后一个 token + // 的下一个位置 accept_idx - 1 是 stop_seq 最后一个 token 的位置 将 + // stop_seq 最后一个 token 替换为 eos,与非 MTP 行为对齐 accept_nums[bid] = accept_idx; accept_tokens_now[accept_idx - 1] = end_ids[0]; - // stop_flags[bid] = true; } } } diff --git a/tests/operators/test_speculate_set_stop_value_multi_seqs.py b/tests/operators/test_speculate_set_stop_value_multi_seqs.py index 45d8a0ef34f..5d61dd5548a 100644 --- a/tests/operators/test_speculate_set_stop_value_multi_seqs.py +++ b/tests/operators/test_speculate_set_stop_value_multi_seqs.py @@ -42,7 +42,7 @@ def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: return paddle_inputs -def run_kernel(paddle_inputs, inputs): +def run_kernel(paddle_inputs): """Call the CUDA kernel.""" speculate_set_stop_value_multi_seqs( paddle_inputs["accept_tokens"], @@ -137,7 +137,11 @@ def gen_inputs( def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str, Any]: - """Python reference — must match CUDA kernel logic exactly.""" + """Python reference — must match CUDA kernel logic exactly. + + New semantics: accept_idx represents the window's end position (where stop_seq's last token lands). + Stop sequence last token is replaced with eos to align with non-MTP behavior. 
+ """ accept_tokens = inputs["accept_tokens"].copy() accept_num = inputs["accept_num"].copy() stop_flags = inputs["stop_flags"].copy() @@ -166,7 +170,7 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str step_idx_now = int(step_idx[bid]) min_token_limit = int(min_tokens[bid]) - can_stop = step_idx_now >= min_token_limit + can_stop = step_idx_now + an >= min_token_limit if not can_stop: continue if stop_flags[bid]: @@ -174,18 +178,41 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str accept_idx = 0 is_end = False - while accept_idx <= an - 1 and not is_end: + + # First check if stop_seq ended in pre_ids (from previous round) + # This handles the case where stop_seq's last token was at the last + # position of the previous round's accept_tokens + if step_idx_now >= stop_seq_len: + pre_ids_end = True + for i in range(stop_seq_len - 1, -1, -1): + pre_ids_idx = step_idx_now - (stop_seq_len - i) + if pre_ids_idx < 0 or pre_ids_now[pre_ids_idx] != stop_seq_now[i]: + pre_ids_end = False + break + if pre_ids_end: + # stop_seq already complete in pre_ids, accept nothing this round + accept_num[bid] = 0 + continue + + # accept_idx is the window end position (stop_seq last token position) + # Only check up to accept_num - 2 to avoid eos write out of bounds + while accept_idx <= an - 2 and not is_end: if step_idx_now + accept_idx + 1 < stop_seq_len: accept_idx += 1 continue # Check one stop_seq match + # offset = stop_seq_len - 1 - i + # accept_tokens_idx = accept_idx - offset for i in range(stop_seq_len - 1, -1, -1): + offset = stop_seq_len - 1 - i + accept_tokens_idx = accept_idx - offset cur_token_idx = -1 - if stop_seq_len - 1 - i < accept_idx: - cur_token_idx = accept_tokens_now[accept_idx - (stop_seq_len - 1 - i) - 1] + + if accept_tokens_idx >= 0: + cur_token_idx = accept_tokens_now[accept_tokens_idx] else: - pre_ids_idx = step_idx_now + accept_idx - (stop_seq_len - 1 - i) + pre_ids_idx = step_idx_now + 
accept_tokens_idx if pre_ids_idx <= 0: break cur_token_idx = pre_ids_now[pre_ids_idx] @@ -199,9 +226,10 @@ def reference_spec_set_stop_value_multi_seqs(inputs: Dict[str, Any]) -> Dict[str accept_idx += 1 if is_end: + # After the loop increment, accept_idx is one past stop_seq's last token (that token sits at accept_idx - 1) + # Replace stop_seq last token with eos, aligning with non-MTP behavior accept_num[bid] = accept_idx accept_tokens[bid, accept_idx - 1] = end_ids[0] - # stop_flags[bid] = True # kernel no longer sets stop_flags return { "accept_tokens": accept_tokens, @@ -239,7 +267,7 @@ class TestSpeculateSetStopValueMultiSeqs(unittest.TestCase): def _run_and_get(self, inputs): paddle_inputs = to_paddle_inputs(inputs) - run_kernel(paddle_inputs, inputs) + run_kernel(paddle_inputs) return get_outputs(paddle_inputs) def _check_all_outputs(self, inputs, outputs): @@ -264,7 +292,7 @@ def test_configs(self): self._run_full_test(test_cfg) def test_match_in_accept_tokens_only(self): - """Stop seq found entirely within accept_tokens.""" + """Stop seq found entirely within accept_tokens. Last token replaced with eos.""" inputs = gen_inputs(real_bsz=1, accept_tokens_len=5, stop_seqs_bs=1, stop_seqs_max_len=3, seed=10) # Place stop seq [A, B, C] at accept_tokens positions [0,1,2] inputs["accept_num"][:] = 4 @@ -276,9 +304,13 @@ def test_match_in_accept_tokens_only(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # stop_seq [10, 20, 30] matches at accept_idx=2 (window ends at accept_tokens[2]=30) + # After loop, accept_idx=3, accept_num=3, accept_tokens[2] replaced with eos (-1) + self.assertEqual(outputs["accept_num"][0], 3) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # replaced with eos def test_match_spanning_pre_ids_and_accept(self): - """Stop seq spans token_ids_all (pre_ids) and accept_tokens.""" + """Stop seq spans token_ids_all (pre_ids) and accept_tokens. 
Last token replaced with eos.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -290,12 +322,12 @@ def test_match_spanning_pre_ids_and_accept(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 6 inputs["accept_num"][:] = 3 - # Kernel matching at accept_idx=2 (3rd token, 0-indexed): - # i=2(last): stop_seq_len-1-i=0 < accept_idx(2) -> accept_tokens[2-0-1]=accept_tokens[1] - # i=1: stop_seq_len-1-i=1 < accept_idx(2) -> accept_tokens[2-1-1]=accept_tokens[0] - # i=0: stop_seq_len-1-i=2 >= accept_idx(2) -> pre_ids[step_idx+2-(3-1-0)]=pre_ids[6] - # So stop_seq should be [pre_ids[6], accept_tokens[0], accept_tokens[1]] - inputs["token_ids_all"][0, 6] = 99 + # stop_seq = [99, 11, 22] (len=3) + # At accept_idx=1 (window ends at accept_tokens[1]=22): + # i=2: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=22 vs stop_seq[2]=22 ✓ + # i=1: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=11 vs stop_seq[1]=11 ✓ + # i=0: offset=2, accept_tokens_idx=-1 -> pre_ids[step_idx+(-1)]=pre_ids[5]=99 vs stop_seq[0]=99 ✓ + inputs["token_ids_all"][0, 5] = 99 inputs["accept_tokens"][0, :3] = [11, 22, 33] inputs["stop_seqs"][0, 0, :3] = [99, 11, 22] inputs["stop_seqs_len"][0, 0] = 3 @@ -303,12 +335,15 @@ def test_match_spanning_pre_ids_and_accept(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - # Match at accept_idx=2, loop increments to 3 - self.assertEqual(outputs["accept_num"][0], 3) - self.assertEqual(outputs["accept_tokens"][0, 2], -1) - - def test_match_in_pre_ids_only(self): - """Stop seq found entirely within token_ids_all (pre_ids), matching at accept_idx=0.""" + # Match at accept_idx=1, loop increments to 2 -> accept_num=2 + # accept_tokens[1] replaced with eos (-1) + self.assertEqual(outputs["accept_num"][0], 2) + self.assertEqual(outputs["accept_tokens"][0, 1], -1) # replaced with eos + + def test_match_in_pre_ids_only_not_detected(self): + """Stop seq ending purely in pre_ids history is NOT 
detected. + The accept-window scan only matches windows that end on an accepted token; the separate pre_ids tail check + compares pre_ids[step_idx-3 .. step_idx-1] (indices 5..7), which the tokens placed at 6..8 do not line up with.""" inputs = gen_inputs( real_bsz=1, accept_tokens_len=5, @@ -320,12 +355,7 @@ def test_match_in_pre_ids_only(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # pre_ids at step_idx positions: token_ids_all[0, 6]=50, [0,7]=60, [0,8]=70 - # stop_seq = [50, 60, 70], all 3 tokens are in pre_ids - # For accept_idx=0: step_idx_now + 0 + 1 = 9 >= stop_seq_len=3, so we check - # i=2: pre_ids_idx = 8+0-(3-1-2) = 8 -> pre_ids_now[8] = 70 - # i=1: pre_ids_idx = 8+0-(3-1-1) = 7 -> pre_ids_now[7] = 60 - # i=0: pre_ids_idx = 8+0-(3-1-0) = 6 -> pre_ids_now[6] = 50 + # stop_seq [50, 60, 70] sits at token_ids_all[6..8]; the pre_ids tail check reads indices 5..7 and no accept token matches it inputs["token_ids_all"][0, 6] = 50 inputs["token_ids_all"][0, 7] = 60 inputs["token_ids_all"][0, 8] = 70 @@ -336,7 +366,8 @@ def test_match_in_pre_ids_only(self): inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) - self.assertEqual(outputs["accept_num"][0], 1) + # No match: accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 3) def test_already_stopped(self): """Kernel skips sequences with stop_flags=True.""" @@ -371,7 +402,7 @@ def test_min_tokens_blocks_stop(self): inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] inputs["stop_seqs_len"][0, 0] = 3 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 100 # step_idx=8 < 100, should NOT stop + inputs["min_tokens"][:] = 100 # step_idx+accept_num=11 < 100, should NOT stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) @@ -388,15 +419,16 @@ def test_min_tokens_allows_stop(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # Put stop_seq entirely in pre_ids (same pattern as test_match_in_pre_ids_only) - inputs["token_ids_all"][0, 6] = 50 - inputs["token_ids_all"][0, 
7] = 60 - inputs["token_ids_all"][0, 8] = 70 - inputs["accept_tokens"][0, :3] = [1, 2, 3] - inputs["stop_seqs"][0, 0, :3] = [50, 60, 70] - inputs["stop_seqs_len"][0, 0] = 3 + # stop_seq [X, 50] spans pre_ids and accept_tokens[0]. + # At accept_idx=0 (window ends at accept_tokens[0]=50): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=50 vs stop_seq[1]=50 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids[8+(-1)]=pre_ids[7] + pre_val = int(inputs["token_ids_all"][0, 7]) + inputs["accept_tokens"][0, :3] = [50, 60, 70] + inputs["stop_seqs"][0, 0, :2] = [pre_val, 50] + inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False - inputs["min_tokens"][:] = 5 # step_idx=8 >= 5, should stop + inputs["min_tokens"][:] = 5 # step_idx+accept_num=11 >= 5, should stop outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) @@ -413,20 +445,24 @@ def test_multiple_stop_seqs_second_matches(self): inputs["prompt_lens"][:] = 0 inputs["step_idx"][:] = 8 inputs["accept_num"][:] = 3 - # accept_tokens: stop_seq[20,30] matches at accept_idx=2: - # i=1: accept_tokens[2-0-1]=accept_tokens[1]=30 vs stop_seq[1]=30 OK - # i=0: accept_tokens[2-1-1]=accept_tokens[0]=20 vs stop_seq[0]=20 OK + # accept_tokens: [20, 30, 40] + # Second stop seq [20, 30] matches at accept_idx=1 (window ends at accept_tokens[1]=30): + # i=1: offset=0, accept_tokens_idx=1 -> accept_tokens[1]=30 vs stop_seq[1]=30 ✓ + # i=0: offset=1, accept_tokens_idx=0 -> accept_tokens[0]=20 vs stop_seq[0]=20 ✓ inputs["accept_tokens"][0, :3] = [20, 30, 40] # First stop seq doesn't match inputs["stop_seqs"][0, 0, :3] = [99, 98, 97] inputs["stop_seqs_len"][0, 0] = 3 - # Second stop seq matches + # Second stop seq [20, 30] matches inputs["stop_seqs"][0, 1, :2] = [20, 30] inputs["stop_seqs_len"][0, 1] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=1 -> accept_num=2, accept_tokens[1] 
replaced with eos + self.assertEqual(outputs["accept_num"][0], 2) + self.assertEqual(outputs["accept_tokens"][0, 1], -1) # replaced with eos def test_nonzero_prompt_lens(self): """Verify prompt_lens offset is applied correctly.""" @@ -444,19 +480,99 @@ def test_nonzero_prompt_lens(self): inputs["accept_num"][:] = 2 inputs["accept_tokens"][0, :2] = [55, 66] # pre_ids_now starts at token_ids_all[0, prompt_len:] - # stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx] - # For accept_idx=0: pre_ids_idx = step_idx + 0 - (2-1-0) = 5-1 = 4 - # -> pre_ids_now[4] = token_ids_all[0, prompt_len + 4] - # For accept_idx=1 (second token is accept_tokens[0,0]=55): - # i=1: accept_tokens_now[1-(2-1-1)-1] = accept_tokens_now[0] = 55 - # i=0: pre_ids_idx = step_idx + 1 - (2-1-0) = 5+1-1 = 5 -> pre_ids_now[5] - target_val = int(inputs["token_ids_all"][0, prompt_len + 5]) + # stop_seq = [X, 55] where X = token_ids_all[0, prompt_len + step_idx - 1] + # At accept_idx=0 (window ends at accept_tokens[0]=55): + # i=1: offset=0, accept_tokens_idx=0 -> accept_tokens[0]=55 vs stop_seq[1]=55 ✓ + # i=0: offset=1, accept_tokens_idx=-1 -> pre_ids[step_idx+(-1)]=pre_ids[4]=token_ids_all[0, prompt_len+4] + target_val = int(inputs["token_ids_all"][0, prompt_len + 4]) inputs["stop_seqs"][0, 0, :2] = [target_val, 55] inputs["stop_seqs_len"][0, 0] = 2 inputs["stop_flags"][:] = False inputs["min_tokens"][:] = 0 outputs = self._run_and_get(inputs) self._check_all_outputs(inputs, outputs) + # Match at accept_idx=0 -> accept_num=1, accept_tokens[0] replaced with eos + self.assertEqual(outputs["accept_num"][0], 1) + self.assertEqual(outputs["accept_tokens"][0, 0], -1) # replaced with eos + + def test_single_token_stop_seq_preserved(self): + """Single token stop_seq (like <|im_end|>) last token replaced with eos.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=90, + ) + inputs["prompt_lens"][:] = 0 + 
inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # accept_tokens: [a, b, <|im_end|>, d] where <|im_end|> has token id 999 + inputs["accept_tokens"][0, :4] = [100, 200, 999, 300] + # stop_seq = [<|im_end|>] (single token) + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # Match at accept_idx=2 (window ends at accept_tokens[2]=999) + # After loop increment, accept_idx=3, accept_num=3, accept_tokens[2] replaced with eos + self.assertEqual(outputs["accept_num"][0], 3) + self.assertEqual(outputs["accept_tokens"][0, 2], -1) # replaced with eos + + def test_stop_seq_at_last_position_not_detected(self): + """Stop seq at the last position of accept_tokens is NOT detected (deferred to next round).""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=100, + ) + inputs["prompt_lens"][:] = 0 + inputs["step_idx"][:] = 10 + inputs["accept_num"][:] = 4 + # stop_seq [999] is at accept_tokens[3] (last valid position) + # Since we only check up to accept_num - 2 = 2, this won't be detected + inputs["accept_tokens"][0, :4] = [100, 200, 300, 999] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # No match because accept_idx only goes up to 2, and 999 is at position 3 + # accept_num unchanged + self.assertEqual(outputs["accept_num"][0], 4) + + def test_stop_seq_detected_from_previous_round(self): + """Stop seq at the end of pre_ids (from previous round) is detected.""" + inputs = gen_inputs( + real_bsz=1, + accept_tokens_len=5, + max_model_len=32, + stop_seqs_bs=1, + stop_seqs_max_len=1, + seed=110, + ) + inputs["prompt_lens"][:] = 0 + # Simulate previous round: 
stop_seq [999] is at pre_ids[9] + # step_idx = 10 means pre_ids has indices 0-9 + inputs["step_idx"][:] = 10 + inputs["token_ids_all"][0, 9] = 999 # pre_ids last position + inputs["accept_num"][:] = 3 + inputs["accept_tokens"][0, :3] = [100, 200, 300] + inputs["stop_seqs"][0, 0, 0] = 999 + inputs["stop_seqs_len"][0, 0] = 1 + inputs["stop_flags"][:] = False + inputs["min_tokens"][:] = 0 + outputs = self._run_and_get(inputs) + self._check_all_outputs(inputs, outputs) + # stop_seq [999] was in pre_ids, so accept_num = 0 (accept nothing this round) + self.assertEqual(outputs["accept_num"][0], 0) if __name__ == "__main__":