Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 8 additions & 14 deletions custom_ops/xpu_ops/src/ops/gather_next_token.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
bool is_speculative,
int max_bsz) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
Expand Down Expand Up @@ -73,7 +73,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const_cast<int32_t*>(decoder_batch_map.data<int32_t>())};

paddle::Tensor out;
if (output_padding_offset) {
if (is_speculative) {
int need_delete_token_num = 0;
if (enc_batch > 0) {
need_delete_token_num =
Expand All @@ -88,7 +88,7 @@ std::vector<paddle::Tensor> GatherNextToken(
return {out};
}

if (output_padding_offset) {
if (is_speculative) {
int r = fastdeploy::plugin::eb_mtp_gather_next_token<XPUType, XPUType>(
ctx,
reinterpret_cast<const XPUType*>(x.data<data_t>()),
Expand Down Expand Up @@ -124,14 +124,10 @@ std::vector<std::vector<int64_t>> GatherNextTokenInferShape(
const std::vector<int64_t>& encoder_batch_map_cpu_shape,
const std::vector<int64_t>& decoder_batch_map_cpu_shape,
const std::vector<int64_t>& len_info_cpu_shape,
const paddle::optional<std::vector<int64_t>>& output_padding_offset_shape) {
// if (output_padding_offset_shape) {
// PD_THROW("speculative decoding is not supported in XPU.");
// }
// int64_t bsz = cum_offsets_shape[0];
bool is_speculative) {
int64_t bsz = 0;
int64_t dim_embed = x_shape[1];
if (output_padding_offset_shape) {
if (is_speculative) {
return {{-1, dim_embed}};
} else {
return {{bsz, dim_embed}};
Expand All @@ -148,8 +144,7 @@ std::vector<paddle::DataType> GatherNextTokenInferDtype(
const paddle::DataType& decoder_seq_lod_cpu_dtype,
const paddle::DataType& encoder_batch_map_cpu_dtype,
const paddle::DataType& decoder_batch_map_cpu_dtype,
const paddle::DataType& len_info_cpu_dtype,
const paddle::optional<paddle::DataType>& output_padding_offset_dtype) {
const paddle::DataType& len_info_cpu_dtype) {
return {x_dtype};
}

Expand All @@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token)
"decoder_seq_lod_cpu",
"encoder_batch_map_cpu",
"decoder_batch_map_cpu",
"len_info_cpu",
paddle::Optional("output_padding_offset")})
"len_info_cpu"})
.Outputs({"out"})
.Attrs({"max_bsz: int"})
.Attrs({"is_speculative: bool", "max_bsz: int"})
.SetKernelFn(PD_KERNEL(GatherNextToken))
.SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype));
209 changes: 209 additions & 0 deletions custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Verification kernel — outputs step_output_ids + step_output_len,
// and performs EOS / max_dec_len detection (read-only on step_idx).
// step_idx is NOT modified here; all state updates (including step_idx)
// are handled by unified_update_model_status.
//
// Verification strategies:
//   0 = TOPP         : draft token in top-p candidate set (+ verify_window fallback)
//   1 = GREEDY       : draft token == top-1 token (strict argmax match)
//   2 = TARGET_MATCH : draft token == target model's sampled token

#include <paddle/phi/backends/xpu/xpu_context.h>

#include <memory>
#include <random>
#include <vector>

#include "paddle/extension.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace api = baidu::xpu::api;

// ============================================================
// Host function
// ============================================================
// Verify speculative draft tokens against the target model and write the
// accepted tokens into step_output_ids / step_output_len (in place).
// Read-only on step_idx; all state updates happen elsewhere
// (unified_update_model_status).
void VerifyDraftTokens(
    // Core I/O (step_output_ids / step_output_len are mutated in place)
    const paddle::Tensor &step_output_ids,
    const paddle::Tensor &step_output_len,
    const paddle::Tensor &step_input_ids,
    // Target model outputs (optional, required for GREEDY / TARGET_MATCH)
    const paddle::optional<paddle::Tensor> &target_tokens,
    // Candidate set (optional, required for TOPP)
    const paddle::optional<paddle::Tensor> &candidate_ids,
    const paddle::optional<paddle::Tensor> &candidate_scores,
    const paddle::optional<paddle::Tensor> &candidate_lens,
    // Sampling params
    const paddle::Tensor &topp,
    // Metadata
    const paddle::Tensor &stop_flags,
    const paddle::Tensor &seq_lens_encoder,
    const paddle::Tensor &seq_lens_this_time,
    const paddle::Tensor &end_tokens,
    const paddle::Tensor &is_block_step,
    const paddle::Tensor &cu_seqlens_q_output,
    const paddle::Tensor &reasoning_status,
    // max_dec_len / step_idx for EOS/max-len detection (read-only)
    const paddle::Tensor &max_dec_len,
    const paddle::Tensor &step_idx,
    int max_seq_len,
    int verify_window,
    int verify_strategy,
    bool reject_all,
    bool accept_all) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  api::Context *ctx =
      static_cast<const phi::XPUContext *>(dev_ctx)->x_context();
  bool xpu_ctx_flag = true;
  // CPU fallback: own the temporary context via unique_ptr so it is
  // released on every exit path. The previous raw `new` + trailing
  // `delete ctx` leaked the context whenever PD_THROW / PD_CHECK fired
  // before reaching the end of the function.
  std::unique_ptr<api::Context> cpu_ctx;
  if (step_output_ids.is_cpu()) {
    cpu_ctx = std::make_unique<api::Context>(api::kCPU);
    ctx = cpu_ctx.get();
    xpu_ctx_flag = false;
  }

  auto bsz = step_output_ids.shape()[0];
  auto real_bsz = seq_lens_this_time.shape()[0];
  auto max_step_tokens = step_input_ids.shape()[1];
  auto end_length = end_tokens.shape()[0];
  // max_candidate_len: 1 if candidate_ids not provided, else from shape
  int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1;

  // curand state: only needed for TOPP(0) strategy (stochastic sampling).
  // NOTE(review): random_seed is hard-coded to 0 and every batch slot is
  // seeded identically, so the sampled thresholds are the same across
  // calls and batches. If this determinism is not intentional
  // (debug/repro), thread a real seed through the op's attributes.
  int random_seed = 0;
  std::vector<int64_t> infer_seed(bsz, random_seed);
  std::uniform_real_distribution<float> dist(0.0, 1.0);
  std::vector<float> dev_curand_states_cpu;
  dev_curand_states_cpu.reserve(bsz);
  for (int64_t i = 0; i < bsz; i++) {
    std::mt19937_64 engine(infer_seed[i]);
    dev_curand_states_cpu.push_back(dist(engine));
  }
  // Default to the host buffer (valid for the CPU-ctx path); replaced by
  // the device copy below when running on XPU.
  float *dev_curand_states_xpu = dev_curand_states_cpu.data();
  auto dev_curand_states_tensor =
      paddle::empty({static_cast<int64_t>(dev_curand_states_cpu.size())},
                    paddle::DataType::FLOAT32,
                    step_output_ids.place());
  if (xpu_ctx_flag) {
    int h2d_ret =
        xpu::do_host2device(ctx,
                            dev_curand_states_cpu.data(),
                            dev_curand_states_tensor.data<float>(),
                            dev_curand_states_cpu.size() * sizeof(float));
    PD_CHECK(h2d_ret == 0, "do_host2device failed for curand states.");
    dev_curand_states_xpu = dev_curand_states_tensor.data<float>();
  }

  // Get data pointers (nullptr if optional not provided)
  const int64_t *target_tokens_ptr =
      target_tokens ? target_tokens->data<int64_t>() : nullptr;
  const int64_t *candidate_ids_ptr =
      candidate_ids ? candidate_ids->data<int64_t>() : nullptr;
  const float *candidate_scores_ptr =
      candidate_scores ? candidate_scores->data<float>() : nullptr;
  const int *candidate_lens_ptr =
      candidate_lens ? candidate_lens->data<int>() : nullptr;

  // Validate parameters based on verify_strategy.
  // Note: empty_input_forward may lead to empty optional tensors — only
  // validate when bsz > 0 (i.e. there are active sequences).
  if (bsz > 0) {
    if (verify_strategy == 0 /* TOPP */) {
      if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) {
        PD_THROW(
            "verify_strategy=TOPP (0) requires candidate_ids, "
            "candidate_scores, candidate_lens");
      }
    } else if (verify_strategy == 1 /* GREEDY */) {
      if (!target_tokens_ptr) {
        PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)");
      }
    } else if (verify_strategy == 2 /* TARGET_MATCH */) {
      if (!target_tokens_ptr) {
        PD_THROW(
            "verify_strategy=TARGET_MATCH (2) requires target_tokens "
            "(sampled)");
      }
    }
  }
  int ret = fastdeploy::plugin::verify_draft_tokens(
      ctx,
      // Core I/O
      const_cast<int64_t *>(step_output_ids.data<int64_t>()),
      const_cast<int *>(step_output_len.data<int>()),
      step_input_ids.data<int64_t>(),
      // Target model outputs
      target_tokens_ptr,
      // Candidate set
      candidate_ids_ptr,
      candidate_scores_ptr,
      candidate_lens_ptr,
      // Sampling params
      dev_curand_states_xpu,
      topp.data<float>(),
      // Metadata
      stop_flags.data<bool>(),
      seq_lens_encoder.data<int>(),
      seq_lens_this_time.data<int>(),
      end_tokens.data<int64_t>(),
      is_block_step.data<bool>(),
      cu_seqlens_q_output.data<int>(),
      reasoning_status.data<int>(),
      // max_dec_len / step_idx
      max_dec_len.data<int64_t>(),
      step_idx.data<int64_t>(),
      // Dimensions and config
      bsz,             // max_bsz
      real_bsz,        // real_bsz
      max_step_tokens,
      end_length,
      max_seq_len,
      max_candidate_len,
      verify_window,
      verify_strategy,
      reject_all,
      accept_all);
  PD_CHECK(ret == 0, "verify_draft_tokens failed.");
  // cpu_ctx (if any) is destroyed here by unique_ptr.
}

// Static-graph registration for the verify_draft_tokens custom op.
// step_output_ids / step_output_len are declared inplace via
// SetInplaceMap: the kernel mutates them and re-exposes the same
// buffers as the *_out outputs. target_tokens and the candidate_*
// inputs are optional; which ones are required depends on the
// verify_strategy attribute (validated inside VerifyDraftTokens).
PD_BUILD_STATIC_OP(verify_draft_tokens)
.Inputs({"step_output_ids",
"step_output_len",
"step_input_ids",
paddle::Optional("target_tokens"),
paddle::Optional("candidate_ids"),
paddle::Optional("candidate_scores"),
paddle::Optional("candidate_lens"),
"topp",
"stop_flags",
"seq_lens_encoder",
"seq_lens_this_time",
"end_tokens",
"is_block_step",
"cu_seqlens_q_output",
"reasoning_status",
"max_dec_len",
"step_idx"})
.Outputs({"step_output_ids_out", "step_output_len_out"})
.Attrs({"max_seq_len: int",
"verify_window: int",
"verify_strategy: int",
"reject_all: bool",
"accept_all: bool"})
.SetInplaceMap({{"step_output_ids", "step_output_ids_out"},
{"step_output_len", "step_output_len_out"}})
.SetKernelFn(PD_KERNEL(VerifyDraftTokens));
60 changes: 58 additions & 2 deletions custom_ops/xpu_ops/src/ops/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ std::vector<paddle::Tensor> GatherNextToken(
const paddle::Tensor& encoder_batch_map_cpu,
const paddle::Tensor& decoder_batch_map_cpu,
const paddle::Tensor& len_info_cpu,
const paddle::optional<paddle::Tensor>& output_padding_offset,
bool is_speculative,
int max_bsz);

std::vector<paddle::Tensor> GetImgBoundaries(
Expand Down Expand Up @@ -702,6 +702,36 @@ std::vector<paddle::Tensor> WeightQuantize(const paddle::Tensor& x,
const int32_t arch,
const int32_t group_size);

// Forward declaration of the XPU draft-token verification kernel
// (implemented in custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc).
// Signature must stay byte-for-byte in sync with the definition.
void VerifyDraftTokens(
    // Core I/O
    const paddle::Tensor& step_output_ids,
    const paddle::Tensor& step_output_len,
    const paddle::Tensor& step_input_ids,
    // Target model outputs (optional, required for TARGET_MATCH)
    const paddle::optional<paddle::Tensor>& target_tokens,
    // Candidate set (optional, required for TOPP/GREEDY)
    const paddle::optional<paddle::Tensor>& candidate_ids,
    const paddle::optional<paddle::Tensor>& candidate_scores,
    const paddle::optional<paddle::Tensor>& candidate_lens,
    // Sampling params
    const paddle::Tensor& topp,
    // Metadata
    const paddle::Tensor& stop_flags,
    const paddle::Tensor& seq_lens_encoder,
    const paddle::Tensor& seq_lens_this_time,
    const paddle::Tensor& end_tokens,
    const paddle::Tensor& is_block_step,
    const paddle::Tensor& cu_seqlens_q_output,
    const paddle::Tensor& reasoning_status,
    // max_dec_len / step_idx for EOS/max-len detection
    const paddle::Tensor& max_dec_len,
    const paddle::Tensor& step_idx,
    int max_seq_len,
    int verify_window,
    int verify_strategy,
    bool reject_all,
    bool accept_all);

PYBIND11_MODULE(fastdeploy_ops, m) {
m.def("adjust_batch",
&AdjustBatch,
Expand Down Expand Up @@ -916,7 +946,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("encoder_batch_map_cpu"),
py::arg("decoder_batch_map_cpu"),
py::arg("len_info_cpu"),
py::arg("output_padding_offset"),
py::arg("is_speculative"),
py::arg("max_bsz"),
"Gather next token for XPU");

Expand Down Expand Up @@ -1045,6 +1075,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"),
"Unified update model status");

// Python binding for the speculative-decoding verification kernel.
// Keyword arguments mirror VerifyDraftTokens' parameter order; the
// three optional candidate_* inputs and target_tokens accept None.
m.def("verify_draft_tokens",
&VerifyDraftTokens,
py::arg("step_output_ids"),
py::arg("step_output_len"),
py::arg("step_input_ids"),
py::arg("target_tokens"),
py::arg("candidate_ids"),
py::arg("candidate_scores"),
py::arg("candidate_lens"),
py::arg("topp"),
py::arg("stop_flags"),
py::arg("seq_lens_encoder"),
py::arg("seq_lens_this_time"),
py::arg("end_tokens"),
py::arg("is_block_step"),
py::arg("cu_seqlens_q_output"),
py::arg("reasoning_status"),
py::arg("max_dec_len"),
py::arg("step_idx"),
py::arg("max_seq_len"),
py::arg("verify_window"),
py::arg("verify_strategy"),
py::arg("reject_all"),
py::arg("accept_all"),
"Perform speculative verification for decoding v2");

m.def("mtp_step_paddle",
&MTPStepPaddle,
py::arg("base_model_stop_flags"),
Expand Down
Loading
Loading