
[XPU] [Cherry-Pick] Unify Spec and non-spec branch.(#6947)#7180

Open
Jiajun-Ji wants to merge 7 commits into PaddlePaddle:develop from Jiajun-Ji:mtp-unify-v4

Conversation

Contributor

@Jiajun-Ji Jiajun-Ji commented Apr 3, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

  • No obvious anomaly observed in the output length.

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 3, 2026 04:45

paddle-bot bot commented Apr 3, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR is a cherry-pick of #6685. It unifies the execution and post-processing paths of the spec and non-spec branches on the XPU backend and adds draft-token verification support on XPU, aligning it with the unified architecture on the GPU side.

Changes:

  • Unifies the speculative-method field and the proposer initialization/invocation path on the XPU ModelRunner side, and hooks post-processing into unified_update_model_status.
  • Splits the XPU SpeculativeSampler into "naive sampling" and "verify + sampling" routes, adding a verify_draft_tokens call chain.
  • Adds a new XPU custom op verify_draft_tokens (C++ wrapper + XPU3 kernel) with matching unit tests.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.

Show a summary per file
File | Description
fastdeploy/worker/xpu_model_runner.py | Unifies spec_method naming and proposer initialization/run logic; adjusts share_inputs (adds reasoning_status, etc.) and switches post-processing to unified_update_model_status
fastdeploy/model_executor/xpu_pre_and_post_process.py | Migrates speculative post-processing from speculate_update/speculate_set_value_by_flags_and_idx to unified_update_model_status; adds is_naive_mode/prefill_one_step_stop parameters
fastdeploy/model_executor/layers/sample/sampler.py | Refactors the XPU speculative sampling path: routes between naive sampling and verify_draft_tokens verification sampling
custom_ops/xpu_ops/test/test_verify_draft_tokens.py | Adds a reference-implementation comparison test for the verify_draft_tokens kernel
custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp | Adds the XPU plugin wrapper for verify_draft_tokens (CPU wrapper plus XPU3 launch)
custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu | Adds the XPU3 verify_draft_tokens kernel implementation
custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h | Exports the verify_draft_tokens plugin API declaration
custom_ops/xpu_ops/src/ops/pybind/pybind.cc | Exposes verify_draft_tokens to the Python side
custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc | Adds the Paddle extension OP wrapper and argument validation for verify_draft_tokens

Comment on lines +1455 to +1464
if self.spec_method is None:
self.proposer = None
return
self.proposer = self.spec_method.create_proposer(
self.fd_config,
main_model=self.get_model(),
local_rank=self.local_rank,
device_id=self.device_id,
share_inputs=self.share_inputs,
)

Copilot AI Apr 3, 2026


_init_speculative_proposer now calls create_proposer for every non-None spec_method, so selecting NGRAM/SUFFIX on XPU instantiates the corresponding proposer directly: NgramProposer depends on fastdeploy.model_executor.ops.gpu.ngram_match (a CUDA implementation), and SuffixProposer depends on the arctic_inference package; both are most likely unavailable in an XPU environment, so initialization will throw. Consider explicitly restricting/intercepting unsupported spec_methods in XPUModelRunner (e.g. allow only NAIVE/MTP; raise NotImplementedError for NGRAM/SUFFIX or force NAIVE), consistent with the config validation logic.
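A minimal Python sketch of the suggested guard; the method names and the allowed set are assumptions drawn from this comment, not FastDeploy's actual API:

```python
# Sketch of the suggested guard; enum values and function name are
# assumptions based on the review comment, not the real FastDeploy API.
XPU_SUPPORTED_SPEC_METHODS = {"naive", "mtp"}

def check_spec_method(spec_method):
    """Reject speculative methods whose proposers need CUDA-only ops."""
    if spec_method is None:
        return None  # non-speculative path: proposer stays None
    if spec_method not in XPU_SUPPORTED_SPEC_METHODS:
        raise NotImplementedError(
            f"spec_method '{spec_method}' is not supported on XPU; "
            f"use one of {sorted(XPU_SUPPORTED_SPEC_METHODS)}"
        )
    return spec_method
```

Calling such a check at the top of _init_speculative_proposer would fail fast at init time instead of crashing inside a CUDA-only op later.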

Comment on lines +528 to +539
WRAPPER_CHECK_PTR(ctx, float, real_bsz, curand_states);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);

WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, cu_seqlens_q_output);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, reasoning_status);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, step_idx);

Copilot AI Apr 3, 2026


In the new verify_draft_tokens wrapper, the type arguments to WRAPPER_CHECK_PTR do not match the actual pointer types: seq_lens_this_time is int* but is checked as float; cu_seqlens_q_output/reasoning_status/max_dec_len/step_idx are all checked as bool. This makes the wrapper's argument validation wrong and, in the worst case, can cause compile-time/runtime problems. Consider fixing them to the actual types (int for seq_lens_this_time/cand_lens/seq_lens_encoder/…; int for cu_seqlens_q_output/reasoning_status; int64_t for max_dec_len/step_idx).

output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states + i,

Copilot AI Apr 3, 2026


The TOPP Phase 2 of the cpu_wrapper calls topp_sampling_kernel with curand_states + i (where i is the token position), but the host-side VerifyDraftTokens only allocates bsz random numbers (per batch), so any i > 0 reads out of bounds and the CPU fallback path hits undefined behavior/crashes. Consider unifying the semantics of curand_states (per batch vs. per step position) and making the allocation length consistent with the access pattern (e.g. change to curand_states + bid, or allocate max_step_tokens entries and offset by i in the XPU3 kernel as well).

Suggested change
curand_states + i,
curand_states + bid,

Comment on lines +71 to +103
bool xpu_ctx_flag = true;
if (step_output_ids.is_cpu()) {
ctx = new api::Context(api::kCPU);
xpu_ctx_flag = false;
}

auto bsz = step_output_ids.shape()[0];
auto real_bsz = seq_lens_this_time.shape()[0];
auto max_step_tokens = step_input_ids.shape()[1];
auto end_length = end_tokens.shape()[0];
// max_candidate_len: 1 if candidate_ids not provided, else from shape
int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1;

// curand state: only needed for TOPP(0) strategy (stochastic sampling)
int random_seed = 0;
std::vector<int64_t> infer_seed(bsz, random_seed);
std::uniform_real_distribution<float> dist(0.0, 1.0);
std::vector<float> dev_curand_states_cpu;
for (int i = 0; i < bsz; i++) {
std::mt19937_64 engine(infer_seed[i]);
dev_curand_states_cpu.push_back(dist(engine));
}
float *dev_curand_states_xpu;
if (xpu_ctx_flag) {
xpu::ctx_guard RAII_GUARD(ctx);
dev_curand_states_xpu =
RAII_GUARD.alloc<float>(dev_curand_states_cpu.size());
xpu_memcpy(dev_curand_states_xpu,
dev_curand_states_cpu.data(),
dev_curand_states_cpu.size() * sizeof(float),
XPUMemcpyKind::XPU_HOST_TO_DEVICE);
}


Copilot AI Apr 3, 2026


In VerifyDraftTokens, dev_curand_states_xpu is only allocated when xpu_ctx_flag == true; when step_output_ids.is_cpu() takes the CPU ctx path, the pointer is passed uninitialized to fastdeploy::plugin::verify_draft_tokens (while the wrapper side additionally requires curand_states to be non-null), making the CPU path unstable/crash-prone. Suggestion: also allocate and pass a valid curand_states on the CPU path (or allow nullptr when verify_strategy != TOPP and relax the check on the wrapper/kernel side).

Comment on lines +693 to +695
pass
# if not paddle.is_compiled_with_cuda():
# self.skipTest("Requires CUDA")

Copilot AI Apr 3, 2026


This unit test does not skip based on whether XPU is compiled in (the relevant logic in setUp is commented out), yet the top of the file degrades CUDA_PLACE to CPUPlace when XPU is not compiled in and then still calls the verify_draft_tokens custom op, which is likely to fail/crash outright in environments without XPU. Consider restoring the check in setUp and changing it to paddle.is_compiled_with_xpu() (or at least skip in non-XPU environments).

Suggested change
pass
# if not paddle.is_compiled_with_cuda():
# self.skipTest("Requires CUDA")
if not paddle.is_compiled_with_xpu():
self.skipTest("Requires XPU")

Comment on lines +1047 to +1053
def _normal_sample_xpu(
self,
logits: paddle.Tensor,
probs: paddle.Tensor,
sampling_metadata: SamplingMetadata,
share_inputs: List[paddle.Tensor],
) -> SamplerOutput:

Copilot AI Apr 3, 2026


The share_inputs parameter here is annotated as List[paddle.Tensor], but the function body uses it as a dict (share_inputs["seq_lens_this_time"], etc.). This inconsistency misleads callers/IDE type hints and obscures the real key constraints. Consider unifying the type of share_inputs as Dict[str, paddle.Tensor] (as elsewhere in this file) and updating the signatures of forward_xpu/_normal_sample_xpu/_verify_and_sample_xpu accordingly.

Comment on lines 1069 to +1075
fill_value=max_draft_token_num,
dtype="int32",
)
self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")
self.share_inputs["output_padding_offset"] = paddle.full(
shape=[max_num_seqs * (max_draft_token_num + 1)],
fill_value=0,
dtype="int32",
)
# reasoning_status: per-sequence reasoning phase indicator
# 0=thinking, 1=emitting boundary, 2=response, 3=end
# verify_draft_tokens 在 reasoning_status==1 时强制拒绝所有 draft token
self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32")

Copilot AI Apr 3, 2026


Under XPU speculative decoding, the gather_next_token implementation uses the presence of output_padding_offset to decide between eb_mtp_gather_next_token (output shaped by token_num) and eb_gather_next_token (output shaped by bsz). The current _init_share_inputs removes the initialization of share_inputs["output_padding_offset"], and xpu_pre_process no longer fills that field either, so in speculative mode xpu_process_output will most likely take the wrong branch and produce wrong shapes / wrong hidden_states. Consider restoring and maintaining output_padding_offset (its content can be a placeholder tensor used only to trigger the MTP gather path), or reworking the branch condition of xpu_process_output/gather_next_token to explicitly use the output metadata from speculate_pre_process.
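A numpy sketch of the placeholder initialization described above; the shape is taken from the snippet in this thread, and the real code would use paddle.full:

```python
import numpy as np

max_num_seqs = 4          # illustrative values
max_draft_token_num = 3

# Placeholder sized for up to (max_draft_token_num + 1) tokens per
# sequence; its mere presence routes gather_next_token to the MTP path.
output_padding_offset = np.zeros(
    max_num_seqs * (max_draft_token_num + 1), dtype=np.int32
)
```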


codecov-commenter commented Apr 3, 2026

Codecov Report

❌ Patch coverage is 8.33333% with 33 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@2632e6c). Learn more about missing BASE report.

Files with missing lines | Patch % | Lines
fastdeploy/model_executor/layers/sample/sampler.py | 10.34% | 26 Missing ⚠️
fastdeploy/worker/input_batch.py | 0.00% | 7 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7180   +/-   ##
==========================================
  Coverage           ?   73.84%           
==========================================
  Files              ?      376           
  Lines              ?    52933           
  Branches           ?     8259           
==========================================
  Hits               ?    39088           
  Misses             ?    11114           
  Partials           ?     2731           
Flag Coverage Δ
GPU 73.84% <8.33%> (?)

Flags with carried forward coverage won't be shown.



@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-03 15:40 CST

📋 Review Summary

PR overview: unified refactor of speculative decoding on the XPU platform; adds a verify_draft_tokens kernel implementing draft-token verification.
Scope of change: custom_ops/xpu_ops/ (new C++/XPU kernel and wrapper), Python call layer
Impact tags: [XPU] [Speculative Decoding] [OP]

📝 PR Convention Check

The PR title follows the convention and includes the [XPU][Cherry-Pick] tags. The Modifications section of the description is empty; consider filling in the concrete changes.

Suggested description (Modifications section):

## Modifications
- Add `verify_draft_tokens` XPU kernel supporting three verification strategies: TOPP/GREEDY/TARGET_MATCH
- Add pybind bindings and the Python call interface
- Add unit test `test_verify_draft_tokens.py`

Issues

Level | File | Summary
🔴 Bug | verify_draft_token.cc:82 | RAII guard scope issue causes use-after-free
🔴 Bug | verify_draft_tokens.cpp:562 | Wrong type arguments to WRAPPER_CHECK_PTR

Overall Assessment

The PR implements the draft-token verification kernel for the XPU platform; the overall architecture is clear and supports multiple verification strategies. However, there are two serious bugs — RAII memory management and type checking — that must be fixed before this can be merged.

auto max_step_tokens = step_input_ids.shape()[1];
auto end_length = end_tokens.shape()[0];
// max_candidate_len: 1 if candidate_ids not provided, else from shape
int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1;

🔴 Bug: xpu::ctx_guard RAII scope issue causes use-after-free

RAII_GUARD is destructed at the end of the if block, which frees the dev_curand_states_xpu memory allocated through alloc. But the pointer is passed to fastdeploy::plugin::verify_draft_tokens at line 143, by which point the memory has already been released.

Additionally, when xpu_ctx_flag=false, dev_curand_states_xpu is used without ever being initialized.

Consider extending the scope of RAII_GUARD to the whole function, or using a standalone allocation:

float *dev_curand_states_xpu = nullptr;
xpu::ctx_guard RAII_GUARD(ctx);  // moved outside the if block
if (xpu_ctx_flag) {
    dev_curand_states_xpu =
        RAII_GUARD.alloc<float>(dev_curand_states_cpu.size());
    xpu_memcpy(...);
} else {
    dev_curand_states_xpu = dev_curand_states_cpu.data();  // CPU mode uses the host pointer directly
}

seq_lens_encoder,
seq_lens_this_time,
end_tokens,
is_block_step,

🔴 Bug: wrong type arguments to WRAPPER_CHECK_PTR

Lines 559-567 contain several type mismatches:

WRAPPER_CHECK_PTR(ctx, float, real_bsz, seq_lens_this_time);  // ❌ should be int
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, cu_seqlens_q_output);  // ❌ should be int
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, reasoning_status);     // ❌ should be int
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, max_dec_len);          // ❌ should be int64_t
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, step_idx);             // ❌ should be int64_t

Suggested fix:

WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int, real_bsz + 1, cu_seqlens_q_output);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, reasoning_status);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx);

…utput/batch_id_per_token_output, correct

  WRAPPER_CHECK_PTR types, and fix dynamic gather shape in verify_draft_tokens path.
Copilot AI review requested due to automatic review settings April 3, 2026 08:09
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 8 comments.

Comment on lines +138 to +139
self.spec_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.spec_method is not None

Copilot AI Apr 3, 2026


The current PR title does not follow the repository's Cherry-Pick title convention: the template requires starting with [Cherry-Pick] and appending the original develop PR number at the end (e.g. [Cherry-Pick][XPU] ...(#6685)). In the current title, [Cherry-Pick] is not the first tag, and the trailing number appears to be the current PR rather than the original one. Consider adjusting the title per the convention, and filling in the Modifications/Usage sections of the description with key information for traceability.

Comment on lines +1463 to +1472
if self.spec_method is None:
self.proposer = None
return
self.proposer = self.spec_method.create_proposer(
self.fd_config,
main_model=self.get_model(),
local_rank=self.local_rank,
device_id=self.device_id,
share_inputs=self.share_inputs,
)

Copilot AI Apr 3, 2026


_init_speculative_proposer now calls create_proposer for every SpecMethod other than None. But the proposer implementations for SpecMethod.NGRAM/SUFFIX are not XPU-specific (e.g. NgramProposer depends on fastdeploy.model_executor.ops.gpu.ngram_match), which will cause a runtime exception on the XPU path. Consider explicitly restricting the supported methods in the XPU runner (e.g. only MTP/NAIVE), and either erroring out for unsupported methods or degrading to proposer=None (NAIVE).

WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_input_ids);
// len(target_tokens) = cu_seqlens_q_output[-1]
WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, target_tokens);
WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, candidate_lens);

Copilot AI Apr 3, 2026


In the wrapper argument validation of verify_draft_tokens, candidate_lens is const int*, but WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, ...) checks/dumps it as int64_t. The type mismatch can produce wrong validation/debug output and even undefined behavior under some implementations. Consider changing the type argument to int to match the function signature.

Suggested change
WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, candidate_lens);
WRAPPER_CHECK_PTR_OR_NULL(ctx, int, real_bsz, candidate_lens);

Comment on lines +325 to +328
output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states,

Copilot AI Apr 3, 2026


Phase 2 TOPP sampling in the XPU3 kernel calls topp_sampling_kernel with curand_states directly, without offsetting by bid (or token position), so all sequences share the same random source (only curand_states[0] is ever read). If the semantics of curand_states are per-batch, this should be curand_states + bid; if per-position, it should match the CPU wrapper's offset policy. Consider fixing the offset to avoid abnormal cross-batch sampling correlation.

Suggested change
output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states,
// Use the RNG state for the current batch item to avoid
// cross-batch sampling correlation from sharing curand_states[0].
output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states + bid,

Comment on lines +129 to +139
if inputs["verify_strategy"] == 0: # TOPP — Phase 2 is stochastic
real_bsz = inputs["seq_lens_this_time"].shape[0]
for bid in range(real_bsz):
ref_len = int(step_output_len_ref[bid])
if ref_len > 1:
print(gpu_ids[bid, : ref_len - 1], step_output_ids_ref[bid, : ref_len - 1])
np.testing.assert_array_equal(
gpu_ids[bid, : ref_len - 1],
step_output_ids_ref[bid, : ref_len - 1],
err_msg=f"step_output_ids (accepted) mismatch at bid={bid} ({label})",
)

Copilot AI Apr 3, 2026


compare_results in test_verify_draft_tokens.py prints unconditionally in the TOPP branch, which pollutes the unit-test output and hurts CI log readability. Consider removing the print, or gating it behind a controllable debug flag/log level.

Comment on lines +690 to +696
class TestVerifyDraftTokens(unittest.TestCase):

def setUp(self):
pass
# if not paddle.is_compiled_with_cuda():
# self.skipTest("Requires CUDA")


Copilot AI Apr 3, 2026


This test's setUp() currently performs no skip at all, but the top of the file imports fastdeploy.model_executor.ops.xpu.verify_draft_tokens directly; on CI where the XPU custom ops are not compiled/loaded (or there is no XPU), the import itself can fail, making it impossible to skip. Consider following test_unified_update_model_status in the same directory: check paddle.is_compiled_with_xpu() / op availability in setUp or at module level, and skipTest when the condition is unmet.
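A minimal, runnable sketch of the suggested guard; XPU_AVAILABLE stands in for paddle.is_compiled_with_xpu() (hardcoded here so the sketch runs anywhere), and a real test would also guard the custom-op import behind the same check:

```python
import unittest

# Stand-in for paddle.is_compiled_with_xpu(); hardcoded False so the
# sketch runs anywhere. A real test would also guard the op import.
XPU_AVAILABLE = False

@unittest.skipUnless(XPU_AVAILABLE, "Requires XPU")
class TestVerifyDraftTokens(unittest.TestCase):
    def test_kernel(self):
        self.fail("would exercise verify_draft_tokens here")

# Without XPU, the whole test case is reported as skipped, not failed.
result = unittest.TextTestRunner(verbosity=0).run(
    unittest.defaultTestLoader.loadTestsFromTestCase(TestVerifyDraftTokens)
)
```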

cu_seqlens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
for i in range(real_bsz):
cu_seqlens_q_output[i + 1] = cu_seqlens_q_output[i] + seq_lens_this_time[i]
cu_seqlens_q_output = cu_seqlens_q_output[:real_bsz].astype(np.int32)

Copilot AI Apr 3, 2026


The input generation builds cu_seqlens_q_output as a prefix-sum array of length real_bsz+1 but then slices it to [:real_bsz], dropping the trailing total (while the kernel wrapper validates/assumes real_bsz+1). This makes the test input inconsistent with the real runtime share_inputs["cu_seqlens_q_output"] (usually bsz+1) and can trigger out-of-bounds access / validation failures. Consider keeping the full length (real_bsz+1), or adjusting the kernel/wrapper expectations to match.

Suggested change
cu_seqlens_q_output = cu_seqlens_q_output[:real_bsz].astype(np.int32)
cu_seqlens_q_output = cu_seqlens_q_output.astype(np.int32, copy=False)
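The full-length prefix sum described above can be written with np.cumsum; a standalone numpy sketch with illustrative sequence lengths:

```python
import numpy as np

seq_lens_this_time = np.array([3, 1, 2], dtype=np.int32)  # illustrative
real_bsz = len(seq_lens_this_time)

# Keep all real_bsz + 1 entries, including the trailing total,
# matching the bsz + 1 layout the wrapper expects.
cu_seqlens_q_output = np.zeros(real_bsz + 1, dtype=np.int32)
cu_seqlens_q_output[1:] = np.cumsum(seq_lens_this_time)
# cu_seqlens_q_output -> [0, 3, 4, 6]
```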

)
# reasoning_status: per-sequence reasoning phase indicator
# 0=thinking, 1=emitting boundary, 2=response, 3=end
# verify_draft_tokens 在 reasoning_status==1 时强制拒绝所有 draft token

Copilot AI Apr 3, 2026


This newly added code comment is in Chinese ("verify_draft_tokens forcibly rejects all draft tokens when reasoning_status == 1..."), but the repository convention is English comments, for cross-team maintenance and international collaboration. Consider translating the comment to English to match the surrounding comments.

Suggested change
# verify_draft_tokens 在 reasoning_status==1 时强制拒绝所有 draft token
# verify_draft_tokens forcibly rejects all draft tokens when reasoning_status == 1


@fastdeploy-bot fastdeploy-bot left a comment


🤖 AI Code Review | 2026-04-03 17:31 CST

📋 Review Summary

PR overview: a cherry-pick PR that unifies the spec and non-spec branches on the XPU platform and adds a verify_draft_tokens kernel implementing draft-token verification for speculative decoding.

Scope of change: custom_ops/xpu_ops/ (new kernel implementation), fastdeploy/model_executor/ (Python integration)

Impact tags: [XPU] [Speculative Decoding] [OP]

Issues

Level | File | Summary
🟡 Suggestion | verify_draft_token.cc:85 | Random seed hardcoded to 0; TOPP sampling results are fixed
🟡 Suggestion | verify_draft_tokens.xpu:23 | ClusterReduce function is never called (dead code)
🟡 Suggestion | verify_draft_tokens.xpu:62 | xorwow function is never called (dead code)
🟡 Suggestion | verify_draft_tokens.xpu:75 | tid variable assigned but never used

Overall Assessment

The implementation is complete, covering the XPU kernel, the CPU wrapper, and unit tests. The main issues are code copied from other kernels but never used, and the hardcoded random seed, which may hurt sampling diversity under the TOPP strategy. Consider cleaning up the dead code in a follow-up.

int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1;

// curand state: only needed for TOPP(0) strategy (stochastic sampling)
int random_seed = 0;

🟡 Suggestion: random seed hardcoded to 0 makes TOPP sampling results fixed

The current implementation sets random_seed = 0 and initializes std::mt19937_64 with the same seed for every batch, which means:

  1. Every call generates the same random sequence
  2. TOPP sampling behaves identically across all batch items

If this is intentional (e.g. for debugging/reproducibility), add a comment saying so. Otherwise consider:

  • Passing the random seed in from outside
  • Or deriving the seed from a timestamp/request ID
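A standalone Python sketch of per-batch seeding along those lines; base_seed and bsz are illustrative, and the real code would use std::mt19937_64 on the C++ side:

```python
import random

bsz = 4
base_seed = 12345  # would come from config / request ID in practice

# One independent draw per batch element, each with a distinct seed,
# so batches no longer share a single fixed random sequence.
curand_states = [random.Random(base_seed + i).random() for i in range(bsz)]
```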

res = vextract_int32x16(v1, 1);
return res;
}
static inline __device__ int ClusterReduce(

🟡 Suggestion: the ClusterReduce and v_reduce functions are never called in this kernel

These two functions (lines 9-38) were copied from kernels such as speculate_update.xpu but are not used in the current verify_draft_tokens kernel. Consider removing this dead code to keep the file tidy.

return false;
}

static __device__ inline unsigned int xorwow(unsigned int &state) {

🟡 Suggestion: the xorwow function is never called

This random-number generator is defined but never used in the kernel. TOPP sampling currently uses the curand_states passed in from the host side rather than generating numbers in-kernel. Consider removing this dead code.

__global_ptr__ const float *dev_curand_states,
const int candidate_len,
const float topp) {
const int tid = core_id();

🟡 Suggestion: unused variable tid

const int tid = core_id();

The variable is assigned but never used afterwards; consider removing it to silence the compiler warning.
