diff --git a/custom_ops/xpu_ops/src/ops/gather_next_token.cc b/custom_ops/xpu_ops/src/ops/gather_next_token.cc index 31c2142ca07..ee261965d6b 100644 --- a/custom_ops/xpu_ops/src/ops/gather_next_token.cc +++ b/custom_ops/xpu_ops/src/ops/gather_next_token.cc @@ -32,7 +32,7 @@ std::vector GatherNextToken( const paddle::Tensor& encoder_batch_map_cpu, const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& len_info_cpu, - const paddle::optional& output_padding_offset, + bool is_speculative, int max_bsz) { phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); @@ -73,7 +73,7 @@ std::vector GatherNextToken( const_cast(decoder_batch_map.data())}; paddle::Tensor out; - if (output_padding_offset) { + if (is_speculative) { int need_delete_token_num = 0; if (enc_batch > 0) { need_delete_token_num = @@ -88,7 +88,7 @@ std::vector GatherNextToken( return {out}; } - if (output_padding_offset) { + if (is_speculative) { int r = fastdeploy::plugin::eb_mtp_gather_next_token( ctx, reinterpret_cast(x.data()), @@ -124,14 +124,10 @@ std::vector> GatherNextTokenInferShape( const std::vector& encoder_batch_map_cpu_shape, const std::vector& decoder_batch_map_cpu_shape, const std::vector& len_info_cpu_shape, - const paddle::optional>& output_padding_offset_shape) { - // if (output_padding_offset_shape) { - // PD_THROW("speculative decoding is not supported in XPU."); - // } - // int64_t bsz = cum_offsets_shape[0]; + bool is_speculative) { int64_t bsz = 0; int64_t dim_embed = x_shape[1]; - if (output_padding_offset_shape) { + if (is_speculative) { return {{-1, dim_embed}}; } else { return {{bsz, dim_embed}}; @@ -148,8 +144,7 @@ std::vector GatherNextTokenInferDtype( const paddle::DataType& decoder_seq_lod_cpu_dtype, const paddle::DataType& encoder_batch_map_cpu_dtype, const paddle::DataType& decoder_batch_map_cpu_dtype, - const paddle::DataType& len_info_cpu_dtype, - const paddle::optional& 
output_padding_offset_dtype) { + const paddle::DataType& len_info_cpu_dtype) { return {x_dtype}; } @@ -163,10 +158,9 @@ PD_BUILD_STATIC_OP(gather_next_token) "decoder_seq_lod_cpu", "encoder_batch_map_cpu", "decoder_batch_map_cpu", - "len_info_cpu", - paddle::Optional("output_padding_offset")}) + "len_info_cpu"}) .Outputs({"out"}) - .Attrs({"max_bsz: int"}) + .Attrs({"is_speculative: bool", "max_bsz: int"}) .SetKernelFn(PD_KERNEL(GatherNextToken)) .SetInferShapeFn(PD_INFER_SHAPE(GatherNextTokenInferShape)) .SetInferDtypeFn(PD_INFER_DTYPE(GatherNextTokenInferDtype)); diff --git a/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc new file mode 100644 index 00000000000..fd3f76e8226 --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc @@ -0,0 +1,209 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Verification kernel — outputs step_output_ids + step_output_len, +// and performs EOS / max_dec_len detection (read-only on step_idx). +// step_idx is NOT modified here; all state updates (including step_idx) +// are handled by unified_update_model_status. 
+// +// Verification strategies: +// 0 = TOPP : draft token in top-p candidate set (+ verify_window +// fallback) 1 = GREEDY : draft token == top-1 token (strict argmax +// match) 2 = TARGET_MATCH : draft token == target model's sampled token + +#include +#include "paddle/extension.h" +#include "xpu/plugin.h" + +#ifndef PD_BUILD_STATIC_OP +#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name) +#endif + +namespace api = baidu::xpu::api; + +// ============================================================ +// Host function +// ============================================================ +void VerifyDraftTokens( + // Core I/O + const paddle::Tensor &step_output_ids, + const paddle::Tensor &step_output_len, + const paddle::Tensor &step_input_ids, + // Target model outputs (optional, required for TARGET_MATCH) + const paddle::optional &target_tokens, + // Candidate set (optional, required for TOPP/GREEDY) + const paddle::optional &candidate_ids, + const paddle::optional &candidate_scores, + const paddle::optional &candidate_lens, + // Sampling params + const paddle::Tensor &topp, + // Metadata + const paddle::Tensor &stop_flags, + const paddle::Tensor &seq_lens_encoder, + const paddle::Tensor &seq_lens_this_time, + const paddle::Tensor &end_tokens, + const paddle::Tensor &is_block_step, + const paddle::Tensor &cu_seqlens_q_output, + const paddle::Tensor &reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection + const paddle::Tensor &max_dec_len, + const paddle::Tensor &step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool reject_all, + bool accept_all) { + phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); + auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); + api::Context *ctx = + static_cast(dev_ctx)->x_context(); + bool xpu_ctx_flag = true; + if (step_output_ids.is_cpu()) { + ctx = new api::Context(api::kCPU); + xpu_ctx_flag = false; + } + + auto bsz = 
step_output_ids.shape()[0]; + auto real_bsz = seq_lens_this_time.shape()[0]; + auto max_step_tokens = step_input_ids.shape()[1]; + auto end_length = end_tokens.shape()[0]; + // max_candidate_len: 1 if candidate_ids not provided, else from shape + int max_candidate_len = candidate_ids ? candidate_ids->shape()[1] : 1; + + // curand state: only needed for TOPP(0) strategy (stochastic sampling) + int random_seed = 0; + std::vector infer_seed(bsz, random_seed); + std::uniform_real_distribution dist(0.0, 1.0); + std::vector dev_curand_states_cpu; + for (int i = 0; i < bsz; i++) { + std::mt19937_64 engine(infer_seed[i]); + dev_curand_states_cpu.push_back(dist(engine)); + } + float *dev_curand_states_xpu = dev_curand_states_cpu.data(); + auto dev_curand_states_tensor = + paddle::empty({static_cast(dev_curand_states_cpu.size())}, + paddle::DataType::FLOAT32, + step_output_ids.place()); + if (xpu_ctx_flag) { + int h2d_ret = + xpu::do_host2device(ctx, + dev_curand_states_cpu.data(), + dev_curand_states_tensor.data(), + dev_curand_states_cpu.size() * sizeof(float)); + PD_CHECK(h2d_ret == 0, "do_host2device failed for curand states."); + dev_curand_states_xpu = dev_curand_states_tensor.data(); + } + + // Get data pointers (nullptr if optional not provided) + const int64_t *target_tokens_ptr = + target_tokens ? target_tokens->data() : nullptr; + const int64_t *candidate_ids_ptr = + candidate_ids ? candidate_ids->data() : nullptr; + const float *candidate_scores_ptr = + candidate_scores ? candidate_scores->data() : nullptr; + const int *candidate_lens_ptr = + candidate_lens ? candidate_lens->data() : nullptr; + + // Validate parameters based on verify_strategy. + // Note: empty_input_forward may lead to empty optional tensors — only + // validate when bsz > 0 (i.e. there are active sequences). 
+ if (bsz > 0) { + if (verify_strategy == 0 /* TOPP */) { + if (!candidate_ids_ptr || !candidate_scores_ptr || !candidate_lens_ptr) { + PD_THROW( + "verify_strategy=TOPP (0) requires candidate_ids, " + "candidate_scores, candidate_lens"); + } + } else if (verify_strategy == 1 /* GREEDY */) { + if (!target_tokens_ptr) { + PD_THROW("verify_strategy=GREEDY (1) requires target_tokens (argmax)"); + } + } else if (verify_strategy == 2 /* TARGET_MATCH */) { + if (!target_tokens_ptr) { + PD_THROW( + "verify_strategy=TARGET_MATCH (2) requires target_tokens " + "(sampled)"); + } + } + } + int ret = fastdeploy::plugin::verify_draft_tokens( + ctx, + // Core I/O + const_cast(step_output_ids.data()), + const_cast(step_output_len.data()), + step_input_ids.data(), + // Target model outputs + target_tokens_ptr, + // Candidate set + candidate_ids_ptr, + candidate_scores_ptr, + candidate_lens_ptr, + // Sampling params + dev_curand_states_xpu, + topp.data(), + // Metadata + stop_flags.data(), + seq_lens_encoder.data(), + seq_lens_this_time.data(), + end_tokens.data(), + is_block_step.data(), + cu_seqlens_q_output.data(), + reasoning_status.data(), + // max_dec_len / step_idx + max_dec_len.data(), + step_idx.data(), + // Dimensions and config + bsz, // max_bsz + real_bsz, // real_bsz + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + PD_CHECK(ret == 0, "verify_draft_tokens failed."); + if (step_output_ids.is_cpu()) { + delete ctx; + } +} + +PD_BUILD_STATIC_OP(verify_draft_tokens) + .Inputs({"step_output_ids", + "step_output_len", + "step_input_ids", + paddle::Optional("target_tokens"), + paddle::Optional("candidate_ids"), + paddle::Optional("candidate_scores"), + paddle::Optional("candidate_lens"), + "topp", + "stop_flags", + "seq_lens_encoder", + "seq_lens_this_time", + "end_tokens", + "is_block_step", + "cu_seqlens_q_output", + "reasoning_status", + "max_dec_len", + "step_idx"}) + 
.Outputs({"step_output_ids_out", "step_output_len_out"}) + .Attrs({"max_seq_len: int", + "verify_window: int", + "verify_strategy: int", + "reject_all: bool", + "accept_all: bool"}) + .SetInplaceMap({{"step_output_ids", "step_output_ids_out"}, + {"step_output_len", "step_output_len_out"}}) + .SetKernelFn(PD_KERNEL(VerifyDraftTokens)); diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc index 01929009e5b..ebdd433c68f 100644 --- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc +++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc @@ -421,7 +421,7 @@ std::vector GatherNextToken( const paddle::Tensor& encoder_batch_map_cpu, const paddle::Tensor& decoder_batch_map_cpu, const paddle::Tensor& len_info_cpu, - const paddle::optional& output_padding_offset, + bool is_speculative, int max_bsz); std::vector GetImgBoundaries( @@ -702,6 +702,36 @@ std::vector WeightQuantize(const paddle::Tensor& x, const int32_t arch, const int32_t group_size); +void VerifyDraftTokens( + // Core I/O + const paddle::Tensor& step_output_ids, + const paddle::Tensor& step_output_len, + const paddle::Tensor& step_input_ids, + // Target model outputs (optional, required for TARGET_MATCH) + const paddle::optional& target_tokens, + // Candidate set (optional, required for TOPP/GREEDY) + const paddle::optional& candidate_ids, + const paddle::optional& candidate_scores, + const paddle::optional& candidate_lens, + // Sampling params + const paddle::Tensor& topp, + // Metadata + const paddle::Tensor& stop_flags, + const paddle::Tensor& seq_lens_encoder, + const paddle::Tensor& seq_lens_this_time, + const paddle::Tensor& end_tokens, + const paddle::Tensor& is_block_step, + const paddle::Tensor& cu_seqlens_q_output, + const paddle::Tensor& reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection + const paddle::Tensor& max_dec_len, + const paddle::Tensor& step_idx, + int max_seq_len, + int verify_window, + int verify_strategy, + bool reject_all, + 
bool accept_all); + PYBIND11_MODULE(fastdeploy_ops, m) { m.def("adjust_batch", &AdjustBatch, @@ -916,7 +946,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("encoder_batch_map_cpu"), py::arg("decoder_batch_map_cpu"), py::arg("len_info_cpu"), - py::arg("output_padding_offset"), + py::arg("is_speculative"), py::arg("max_bsz"), "Gather next token for XPU"); @@ -1045,6 +1075,32 @@ PYBIND11_MODULE(fastdeploy_ops, m) { py::arg("max_draft_tokens"), "Unified update model status"); + m.def("verify_draft_tokens", + &VerifyDraftTokens, + py::arg("step_output_ids"), + py::arg("step_output_len"), + py::arg("step_input_ids"), + py::arg("target_tokens"), + py::arg("candidate_ids"), + py::arg("candidate_scores"), + py::arg("candidate_lens"), + py::arg("topp"), + py::arg("stop_flags"), + py::arg("seq_lens_encoder"), + py::arg("seq_lens_this_time"), + py::arg("end_tokens"), + py::arg("is_block_step"), + py::arg("cu_seqlens_q_output"), + py::arg("reasoning_status"), + py::arg("max_dec_len"), + py::arg("step_idx"), + py::arg("max_seq_len"), + py::arg("verify_window"), + py::arg("verify_strategy"), + py::arg("reject_all"), + py::arg("accept_all"), + "Perform speculative verification for decoding v2"); + m.def("mtp_step_paddle", &MTPStepPaddle, py::arg("base_model_stop_flags"), diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 75bd67e23f0..7bb106f6620 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -767,6 +767,42 @@ DLL_EXPORT int speculate_limit_thinking_content_length_kernel( const int eos_token_id_len, const int inject_len, const bool splitwise_role_is_decode); +DLL_EXPORT int verify_draft_tokens( + api::Context* ctx, + // Core I/O + int64_t* step_output_ids, + int* step_output_len, + const int64_t* step_input_ids, + const int64_t* target_tokens, + // Candidate set for TOPP/GREEDY + const int64_t* candidate_ids, + const float* 
candidate_scores, + const int* candidate_lens, + // Sampling params + const float* curand_states, // nullptr for GREEDY/TARGET_MATCH + const float* topp, + // Metadata + const bool* stop_flags, + const int* seq_lens_encoder, + const int* seq_lens_this_time, + const int64_t* end_tokens, + const bool* is_block_step, + const int* cu_seqlens_q_output, + const int* reasoning_status, + // read-only + const int64_t* max_dec_len, + const int64_t* step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all); /*--------------------------------------- MTP end * --------------------------------------------*/ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu new file mode 100644 index 00000000000..066298e7028 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu @@ -0,0 +1,344 @@ +#include "xpu/kernel/cluster_debug.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/xtdk.h" +#include "xpu/kernel/xtdk_math.h" +#include "xpu/kernel/xtdk_simd.h" + +namespace fd_xpu3 { + +static inline __device__ int v_reduce(int32x16_t &v0, int32x16_t &v1) { + int res; + v1 = vvadd_int32x16(v0, v1); + auto v = vsrlp_int32x16(256, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(128, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(64, v1); + v1 = vvadd_int32x16(v, v1); + v = vsrlp_int32x16(32, v1); + v1 = vvadd_int32x16(v, v1); + res = vextract_int32x16(v1, 1); + return res; +} +static inline __device__ int ClusterReduce( + const _shared_ptr_ int *stop_flag_now_int_sm, int len) { + int sum = 0; + if (core_id() == 
0) { + int32x16_t vec_x_0; + int32x16_t vec_x_1; + int32x16_t vec_y_0 = vzero(); + int32x16_t vec_y_1 = vzero(); + for (int i = 0; i < len; i += 32) { + vload2_sm(stop_flag_now_int_sm + i, vec_x_0, vec_x_1); + vec_y_0 = vvadd_int32x16(vec_y_0, vec_x_0); + vec_y_1 = vvadd_int32x16(vec_y_1, vec_x_1); + } + sum = v_reduce(vec_y_0, vec_y_1); + } + return sum; +} +__device__ bool is_in_end(const int64_t id, + __global_ptr__ const int64_t *end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} +__device__ inline bool is_in(__global_ptr__ const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} + +static __device__ inline unsigned int xorwow(unsigned int &state) { + state ^= state >> 7; + state ^= state << 9; + state ^= state >> 13; + return state; +} + +__device__ int64_t +topp_sampling_kernel(__global_ptr__ const int64_t *candidate_ids, + __global_ptr__ const float *candidate_scores, + __global_ptr__ const float *dev_curand_states, + const int candidate_len, + const float topp) { + const int tid = core_id(); + float sum_scores = 0.0f; + float rand_top_p = *dev_curand_states * topp; + // printf("debug rand_top_p:%f\n",rand_top_p); + for (int i = 0; i < candidate_len; i++) { + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +// GREEDY / TARGET_MATCH: exact single-token match +__device__ inline bool verify_one_match(int64_t target_token, + int64_t draft_token) { + return target_token == draft_token; +} + +__device__ inline bool verify_one_topp( + __global_ptr__ const int64_t *verify_tokens_row, + int64_t draft_token, + int actual_cand_len) { + return is_in(verify_tokens_row, draft_token, actual_cand_len); +} + +// 
============================================================ +// VerifyContext — per-batch mutable state + accept helpers. +// Eliminates repeated EOS/max_dec_len check and output write +// patterns across Phase 1 and Phase 2. +// ============================================================ +struct VerifyContext { + // Immutable per-batch (set once at kernel entry) + int bid; + int max_step_tokens; + int end_length; + __global_ptr__ const int64_t *end_tokens; + __global_ptr__ const int64_t *max_dec_len; + __global_ptr__ const int64_t *step_input_ids_now; + __global_ptr__ int64_t *step_output_ids; + + // Mutable per-batch state + int64_t cur_step_idx; + int output_len_now; + bool stopped; + + // Emit a token at position `pos` to output in Phase 1. + // Performs: step_idx check, EOS detection, token replacement, output write. + // Returns true if this sequence should stop (EOS or max_dec_len hit). + __device__ bool emit_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + if (is_eos || max_len_hit) { + stopped = true; + return true; + } + return false; + } + + // Emit the final token at position `pos` in Phase 2. + // Same EOS/max_dec_len logic. Increments output_len_now since + // Phase 2 produces one additional token. + __device__ void emit_final_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + } + + // TOPP-only: verify_window bulk-accept fallback. 
+ // + // When draft token is NOT in top-p set but IS the top-2 token, + // check verify_window consecutive positions for top-1 match. + // If all match, bulk-accept from position i through ii. + // + // Returns the new loop position (i) after handling. + // Sets *rejected=true if fallback was not triggered (caller should break). + __device__ int try_verify_window_fallback( + int i, + bool *rejected, + __global_ptr__ const int64_t *verify_tokens_now, + int seq_len_this_time, + int max_candidate_len, + int verify_window) { + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == + step_input_ids_now[ii + 1]) { + // top-2 matches — scan verify_window consecutive top-1 matches + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != + step_input_ids_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { + // Bulk accept all tokens from i to ii + for (; i < ii; i++) { + if (emit_token(i, step_input_ids_now[i + 1])) return i; + } + return i; // continue outer loop from position ii + } + } + // Fallback not triggered or insufficient window — reject + *rejected = true; + return i; + } +}; + +__global__ void verify_draft_tokens( + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const 
int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + const int64_t tid = core_id() * cluster_num() + cluster_id(); + const int64_t nthreads = cluster_num() * core_num(); + for (int64_t bid = tid; bid < real_bsz; bid += nthreads) { + step_output_len[bid] = 0; + if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) continue; + const int start_token_id = cu_seqlens_q_output[bid]; + // Pointers are strategy-dependent (may be nullptr for unused params) + auto *candidate_ids_now = + candidate_ids ? candidate_ids + start_token_id * max_candidate_len + : nullptr; + auto *candidate_scores_now = + candidate_scores ? candidate_scores + start_token_id * max_candidate_len + : nullptr; + auto *candidate_lens_now = + candidate_lens ? candidate_lens + start_token_id : nullptr; + auto *target_tokens_now = + target_tokens ? 
target_tokens + start_token_id : nullptr; + + // Initialize per-batch verification context + VerifyContext ctx; + ctx.bid = bid; + ctx.max_step_tokens = max_step_tokens; + ctx.end_length = end_length; + ctx.end_tokens = end_tokens; + ctx.max_dec_len = max_dec_len; + ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens; + ctx.step_output_ids = step_output_ids; + ctx.cur_step_idx = step_idx[bid]; + ctx.output_len_now = 0; + ctx.stopped = false; + + // ======== Phase 1: Verify draft tokens ======== + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + // Early exit conditions: reject-all, prefill, reasoning + if (reject_all || seq_lens_encoder[bid] != 0 || + reasoning_status[bid] == 1) { + break; + } + + // Accept-all override (debug/warmup) + if (accept_all) { + if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break; + continue; + } + + // Strategy dispatch + bool accepted = false; + switch (verify_strategy) { + case 0: { // TOPP + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len, + ctx.step_input_ids_now[i + 1], + actual_cand_len); + if (!accepted) { + bool rejected = false; + i = ctx.try_verify_window_fallback(i, + &rejected, + candidate_ids_now, + seq_lens_this_time[bid], + max_candidate_len, + verify_window); + if (ctx.stopped || rejected) goto phase1_done; + continue; // bulk accept succeeded, continue from new i + } + break; + } + case 1: // GREEDY + case 2: // TARGET_MATCH + accepted = verify_one_match(target_tokens_now[i], + ctx.step_input_ids_now[i + 1]); + break; + } + + if (accepted) { + if (ctx.emit_token(i, ctx.step_input_ids_now[i + 1])) break; + } else { + break; // reject + } + } + phase1_done: + + // ======== Phase 2: Output token for rejected/last position ======== + if (!ctx.stopped) { + int64_t output_token; + switch (verify_strategy) { + case 0: { // TOPP — stochastic sampling from candidate set + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? max_candidate_len + : candidate_lens_now[i]; + output_token = + topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, + candidate_scores_now + i * max_candidate_len, + curand_states, + actual_cand_len, + topp[bid]); + break; + } + case 1: // GREEDY — deterministic argmax from target_tokens + case 2: // TARGET_MATCH — target model's sampled token + output_token = target_tokens_now[i]; + break; + } + ctx.emit_final_token(i, output_token); + } + step_output_len[bid] = ctx.output_len_now; + } +} + +} // namespace fd_xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp new file mode 100644 index 00000000000..b252b099dfd --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp @@ -0,0 +1,613 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "xpu/plugin.h" +#include "xpu/refactor/impl_public/wrapper_check.h" + +namespace fd_xpu3 { +__attribute__((global)) void verify_draft_tokens( + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all); +} // namespace fd_xpu3 + +namespace fastdeploy { +namespace 
plugin { + +// ============================================================ +// Phase 1 helpers — single-step draft token verification +// ============================================================ + +// Check if draft_token appears in the candidate set +static inline bool is_in(const int64_t *candidates, + const int64_t draft, + const int candidate_len) { + for (int i = 0; i < candidate_len; i++) { + if (draft == candidates[i]) { + return true; + } + } + return false; +} +// TOPP: draft in top-p filtered candidate set +static inline bool verify_one_topp(const int64_t *verify_tokens_row, + int64_t draft_token, + int actual_cand_len) { + return is_in(verify_tokens_row, draft_token, actual_cand_len); +} + +// GREEDY / TARGET_MATCH: exact single-token match +static inline bool verify_one_match(int64_t target_token, int64_t draft_token) { + return target_token == draft_token; +} + +static inline bool is_in_end(const int64_t id, + const int64_t *end_ids, + int length) { + bool flag = false; + for (int i = 0; i < length; i++) { + if (id == end_ids[i]) { + return true; + } + } + return flag; +} + +// ============================================================ +// VerifyContext — per-batch mutable state + accept helpers. +// Eliminates repeated EOS/max_dec_len check and output write +// patterns across Phase 1 and Phase 2. +// ============================================================ +struct VerifyContext { + // Immutable per-batch (set once at kernel entry) + int bid; + int max_step_tokens; + int end_length; + const int64_t *end_tokens; + const int64_t *max_dec_len; + const int64_t *step_input_ids_now; + int64_t *step_output_ids; + + // Mutable per-batch state + int64_t cur_step_idx; + int output_len_now; + bool stopped; + + // Emit a token at position `pos` to output in Phase 1. + // Performs: step_idx check, EOS detection, token replacement, output write. + // Returns true if this sequence should stop (EOS or max_dec_len hit). 
+ bool emit_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + if (is_eos || max_len_hit) { + stopped = true; + return true; + } + return false; + } + + // Emit the final token at position `pos` in Phase 2. + // Same EOS/max_dec_len logic. Increments output_len_now since + // Phase 2 produces one additional token. + void emit_final_token(int pos, int64_t token) { + cur_step_idx++; + bool is_eos = is_in_end(token, end_tokens, end_length); + bool max_len_hit = (cur_step_idx >= max_dec_len[bid]); + if ((is_eos || max_len_hit) && !is_eos) { + token = end_tokens[0]; + } + step_output_ids[bid * max_step_tokens + pos] = token; + output_len_now++; + } + + // TOPP-only: verify_window bulk-accept fallback. + // + // When draft token is NOT in top-p set but IS the top-2 token, + // check verify_window consecutive positions for top-1 match. + // If all match, bulk-accept from position i through ii. + // + // Returns the new loop position (i) after handling. + // Sets *rejected=true if fallback was not triggered (caller should break). 
+ int try_verify_window_fallback(int i, + bool *rejected, + const int64_t *verify_tokens_now, + int seq_len_this_time, + int max_candidate_len, + int verify_window) { + int ii = i; + if (max_candidate_len >= 2 && + verify_tokens_now[ii * max_candidate_len + 1] == + step_input_ids_now[ii + 1]) { + // top-2 matches — scan verify_window consecutive top-1 matches + int j = 0; + ii += 1; + for (; j < verify_window && ii < seq_len_this_time - 1; j++, ii++) { + if (verify_tokens_now[ii * max_candidate_len] != + step_input_ids_now[ii + 1]) { + break; + } + } + if (j >= verify_window) { + // Bulk accept all tokens from i to ii + for (; i < ii; i++) { + if (emit_token(i, step_input_ids_now[i + 1])) return i; + } + return i; // continue outer loop from position ii + } + } + // Fallback not triggered or insufficient window — reject + *rejected = true; + return i; + } +}; + +static int64_t topp_sampling_kernel(const int64_t *candidate_ids, + const float *candidate_scores, + const float *dev_curand_states, + const int candidate_len, + const float topp, + int tid) { + // const int tid = core_id(); + float sum_scores = 0.0f; + float rand_top_p = *dev_curand_states * topp; + for (int i = 0; i < candidate_len; i++) { + // printf("debug cpu sample i:%d scores:%f,ids:%ld + // rand_top_p:%f,candidate_len:%d\n", + // i,candidate_scores[i],candidate_ids[i],rand_top_p,candidate_len); + sum_scores += candidate_scores[i]; + if (rand_top_p <= sum_scores) { + return candidate_ids[i]; + } + } + return candidate_ids[0]; +} + +static int cpu_wrapper( + api::Context *ctx, + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling 
params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + for (int bid = 0; bid < max_bsz; bid++) { + step_output_len[bid] = 0; + + if (bid >= real_bsz || is_block_step[bid] || stop_flags[bid]) continue; + + const int start_token_id = cu_seqlens_q_output[bid]; + // Pointers are strategy-dependent (may be nullptr for unused params) + auto *candidate_ids_now = + candidate_ids ? candidate_ids + start_token_id * max_candidate_len + : nullptr; + auto *candidate_scores_now = + candidate_scores ? candidate_scores + start_token_id * max_candidate_len + : nullptr; + auto *candidate_lens_now = + candidate_lens ? candidate_lens + start_token_id : nullptr; + auto *target_tokens_now = + target_tokens ? 
target_tokens + start_token_id : nullptr; + + // Initialize per-batch verification context + VerifyContext v_ctx; + v_ctx.bid = bid; + v_ctx.max_step_tokens = max_step_tokens; + v_ctx.end_length = end_length; + v_ctx.end_tokens = end_tokens; + v_ctx.max_dec_len = max_dec_len; + v_ctx.step_input_ids_now = step_input_ids + bid * max_step_tokens; + v_ctx.step_output_ids = step_output_ids; + v_ctx.cur_step_idx = step_idx[bid]; + v_ctx.output_len_now = 0; + v_ctx.stopped = false; + + // ======== Phase 1: Verify draft tokens ======== + int i = 0; + for (; i < seq_lens_this_time[bid] - 1; i++) { + // Early exit conditions: reject-all, prefill, reasoning + if (reject_all || seq_lens_encoder[bid] != 0 || + reasoning_status[bid] == 1) { + break; + } + + // Accept-all override (debug/warmup) + if (accept_all) { + if (v_ctx.emit_token(i, v_ctx.step_input_ids_now[i + 1])) break; + continue; + } + + // Strategy dispatch + bool accepted = false; + switch (verify_strategy) { + case 0: { // TOPP + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + accepted = verify_one_topp(candidate_ids_now + i * max_candidate_len, + v_ctx.step_input_ids_now[i + 1], + actual_cand_len); + if (!accepted) { + bool rejected = false; + i = v_ctx.try_verify_window_fallback(i, + &rejected, + candidate_ids_now, + seq_lens_this_time[bid], + max_candidate_len, + verify_window); + if (v_ctx.stopped || rejected) goto phase1_done; + continue; // bulk accept succeeded, continue from new i + } + break; + } + case 1: // GREEDY + case 2: // TARGET_MATCH + accepted = verify_one_match(target_tokens_now[i], + v_ctx.step_input_ids_now[i + 1]); + break; + } + + if (accepted) { + if (v_ctx.emit_token(i, v_ctx.step_input_ids_now[i + 1])) break; + } else { + break; // reject + } + } + phase1_done: + + // ======== Phase 2: Output token for rejected/last position ======== + if (!v_ctx.stopped) { + int64_t output_token = 0; + switch (verify_strategy) { + case 0: { // TOPP — stochastic sampling from candidate set + auto actual_cand_len = candidate_lens_now[i] > max_candidate_len + ? 
max_candidate_len + : candidate_lens_now[i]; + output_token = + topp_sampling_kernel(candidate_ids_now + i * max_candidate_len, + candidate_scores_now + i * max_candidate_len, + curand_states + i, + actual_cand_len, + topp[bid], + bid); + break; + } + case 1: // GREEDY — deterministic argmax from target_tokens + case 2: // TARGET_MATCH — target model's sampled token + output_token = target_tokens_now[i]; + break; + } + v_ctx.emit_final_token(i, output_token); + } + step_output_len[bid] = v_ctx.output_len_now; + } + + return api::SUCCESS; +} + +static int xpu3_wrapper( + api::Context *ctx, + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + using XPU_INT64 = typename api::XPUIndexType::type; + int32_t ret_xre = + fd_xpu3::verify_draft_tokens<<ncluster(), 64, ctx->xpu_stream>>>( + reinterpret_cast(step_output_ids), + step_output_len, + 
reinterpret_cast(step_input_ids), + reinterpret_cast(target_tokens), + reinterpret_cast(candidate_ids), + candidate_scores, + candidate_lens, + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + reinterpret_cast(end_tokens), + is_block_step, + cu_seqlens_q_output, + reasoning_status, + reinterpret_cast(max_dec_len), + reinterpret_cast(step_idx), + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + KERNEL_ASSERT_SUCCESS(ctx, ret_xre); + return api::SUCCESS; +} + +int verify_draft_tokens( + api::Context *ctx, + // Core I/O + int64_t *step_output_ids, + int *step_output_len, + const int64_t *step_input_ids, // draft tokens + // Target model outputs (strategy-dependent interpretation) + const int64_t + *target_tokens, // GREEDY:argmax, TARGET_MATCH:sampled, TOPP:unused + // Candidate set for TOPP/GREEDY (TARGET_MATCH: unused) + const int64_t *candidate_ids, + const float *candidate_scores, + const int *candidate_lens, + // Sampling params + const float *curand_states, // nullptr for GREEDY/TARGET_MATCH + const float *topp, + // Metadata + const bool *stop_flags, + const int *seq_lens_encoder, + const int *seq_lens_this_time, + const int64_t *end_tokens, + const bool *is_block_step, + const int *cu_seqlens_q_output, + const int *reasoning_status, + // max_dec_len / step_idx for EOS/max-len detection (read-only) + const int64_t *max_dec_len, + const int64_t *step_idx, + // Dimensions and config + const int max_bsz, + const int real_bsz, + const int max_step_tokens, + const int end_length, + const int max_seq_len, + const int max_candidate_len, + const int verify_window, + const int verify_strategy, // 0=TOPP, 1=GREEDY, 2=TARGET_MATCH + const bool reject_all, + const bool accept_all) { + WRAPPER_CHECK_CTX(ctx); + WRAPPER_DUMP_FUNCTION_T1(ctx, "verify_draft_tokens", int64_t); + WRAPPER_DUMP_PARAM6(ctx, + step_output_ids, + step_output_len, 
+ step_input_ids, + target_tokens, + candidate_ids, + candidate_scores); + WRAPPER_DUMP_PARAM6(ctx, + candidate_lens, + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time); + + WRAPPER_DUMP_PARAM6(ctx, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx); + + WRAPPER_DUMP_PARAM6(ctx, + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len); + + WRAPPER_DUMP_PARAM4( + ctx, verify_window, verify_strategy, reject_all, accept_all); + WRAPPER_DUMP(ctx); + + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_output_ids); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz * max_step_tokens, step_input_ids); + // len(target_tokens) = cu_seqlens_q_output[-1] + WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, target_tokens); + WRAPPER_CHECK_PTR_OR_NULL(ctx, int64_t, real_bsz, candidate_lens); + WRAPPER_CHECK_PTR_OR_NULL( + ctx, int64_t, real_bsz * max_candidate_len, candidate_ids); + WRAPPER_CHECK_PTR_OR_NULL( + ctx, float, real_bsz *max_candidate_len, candidate_scores); + + WRAPPER_CHECK_PTR(ctx, float, real_bsz, curand_states); + WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp); + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time); + WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens); + + WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step); + WRAPPER_CHECK_PTR(ctx, int, real_bsz + 1, cu_seqlens_q_output); + WRAPPER_CHECK_PTR(ctx, int, real_bsz, reasoning_status); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, max_dec_len); + WRAPPER_CHECK_PTR(ctx, int64_t, real_bsz, step_idx); + // param check sm size limit + WRAPPER_ASSERT_GT(ctx, real_bsz, 0); + WRAPPER_ASSERT_LE(ctx, real_bsz, 1024); + WRAPPER_ASSERT_LE(ctx, real_bsz * max_candidate_len, 2048); + WRAPPER_ASSERT_LE(ctx, verify_window * max_candidate_len, 128); + + if (ctx->dev().type() 
== api::kCPU) { + return cpu_wrapper(ctx, + step_output_ids, + step_output_len, + step_input_ids, + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx, + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + } + if (ctx->dev().type() == api::kXPU3) { + return xpu3_wrapper(ctx, + step_output_ids, + step_output_len, + step_input_ids, + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + curand_states, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx, + max_bsz, + real_bsz, + max_step_tokens, + end_length, + max_seq_len, + max_candidate_len, + verify_window, + verify_strategy, + reject_all, + accept_all); + } + WRAPPER_UNIMPLEMENTED(ctx); +} + +} // namespace plugin +} // namespace fastdeploy diff --git a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py index 758dff17e58..bc074242b4e 100644 --- a/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py +++ b/custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py @@ -24,7 +24,7 @@ ) -def _run_test_base(seq_lens_this_time_data, output_padding_offset): +def _run_test_base(seq_lens_this_time_data, is_speculative): """ 通用的基础测试执行函数,包含了两个场景共有的逻辑。 """ @@ -120,7 +120,7 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset): encoder_batch_map_cpu, decoder_batch_map_cpu, len_info_cpu, - output_padding_offset, + is_speculative, -1, ) @@ -136,14 +136,14 @@ def _run_test_base(seq_lens_this_time_data, output_padding_offset): encoder_batch_map_cpu, decoder_batch_map_cpu, 
len_info_cpu, - output_padding_offset, + is_speculative, -1, ) gather_out_np = gather_out.astype("float32").cpu().numpy() gather_out_cpu_np = gather_out_cpu.astype("float32").cpu().numpy() - if output_padding_offset is not None: + if is_speculative: np.testing.assert_allclose(gather_out_np, gather_out_cpu_np, err_msg="gather_next_token check failed!") else: for i in range(gather_out_cpu.shape[0]): @@ -160,19 +160,14 @@ def test_mix_with_mtp(self): """测试混合批次处理中的 MTP (Multi-Token Prediction) 场景""" print("\nRunning test: test_mix_with_mtp") seq_lens_this_time_data = [100, 2, 0, 1, 120, 140, 3] - bsz = len(seq_lens_this_time_data) - output_padding_offset = paddle.zeros(bsz, dtype="int32") - - _run_test_base(seq_lens_this_time_data, output_padding_offset) + _run_test_base(seq_lens_this_time_data, True) print("Test passed for scenario: With MTP") def test_mix_without_mtp(self): """测试非 MTP (Single-Token Prediction) 场景下的功能""" print("\nRunning test: test_mix_without_mtp") seq_lens_this_time_data = [100, 1, 0, 1, 120, 140, 1] - output_padding_offset = None # 非 MTP 场景下,此参数为 None - - _run_test_base(seq_lens_this_time_data, output_padding_offset) + _run_test_base(seq_lens_this_time_data, False) print("Test passed for scenario: Without MTP") diff --git a/custom_ops/xpu_ops/test/test_verify_draft_tokens.py b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py new file mode 100644 index 00000000000..cfe6a6214f6 --- /dev/null +++ b/custom_ops/xpu_ops/test/test_verify_draft_tokens.py @@ -0,0 +1,1039 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Unit tests for verify_draft_tokens kernel. + +Verification strategies: +- TOPP (0): Verify draft token is in top-p candidate set +- GREEDY (1): Verify draft token matches target model's argmax +- TARGET_MATCH (2): Verify draft token matches target model's sampled token +""" + +import random +import unittest +from typing import Any, Dict + +import numpy as np +import paddle + +from fastdeploy.model_executor.ops.xpu import verify_draft_tokens +from fastdeploy.spec_decode import VerifyStrategy + +CPU_PLACE = paddle.CPUPlace() +CUDA_PLACE = paddle.XPUPlace(0) if paddle.is_compiled_with_xpu() else paddle.CPUPlace() + + +# ============================================================ +# Helpers: tensor creation / kernel invocation / comparison +# ============================================================ + + +def to_paddle_inputs(inputs: Dict[str, Any]) -> Dict[str, Any]: + """Convert numpy input dict to paddle tensors on GPU.""" + paddle_inputs = {} + for k, v in inputs.items(): + if isinstance(v, (int, bool, float, str)): + paddle_inputs[k] = v + elif v is not None: + paddle_inputs[k] = paddle.to_tensor(v, place=CUDA_PLACE) + else: + paddle_inputs[k] = None + return paddle_inputs + + +def run_kernel(paddle_inputs: Dict[str, Any], inputs: Dict[str, Any]): + """Call verify_draft_tokens kernel.""" + verify_draft_tokens( + paddle_inputs["step_output_ids"], + paddle_inputs["step_output_len"], + paddle_inputs["step_input_ids"], + paddle_inputs["target_tokens"], + paddle_inputs["candidate_ids"], + paddle_inputs["candidate_scores"], + 
paddle_inputs["candidate_lens"], + paddle_inputs["topp"], + paddle_inputs["stop_flags"], + paddle_inputs["seq_lens_encoder"], + paddle_inputs["seq_lens_this_time"], + paddle_inputs["end_tokens"], + paddle_inputs["is_block_step"], + paddle_inputs["cu_seqlens_q_output"], + paddle_inputs["reasoning_status"], + paddle_inputs["max_dec_len"], + paddle_inputs["step_idx"], + inputs["max_seq_len"], + inputs["verify_window"], + inputs["verify_strategy"], + inputs["reject_all"], + inputs["accept_all"], + ) + + +def run_ref(inputs: Dict[str, Any]): + """Run reference implementation on deep-copied inputs, return (output_ids, output_len).""" + ref = {k: v.copy() if isinstance(v, np.ndarray) else v for k, v in inputs.items()} + return verify_draft_tokens_ref( + ref["step_output_ids"], + ref["step_output_len"], + ref["step_input_ids"], + ref["target_tokens"], + ref["candidate_ids"], + ref["candidate_scores"], + ref["candidate_lens"], + ref["topp"], + ref["stop_flags"], + ref["seq_lens_encoder"], + ref["seq_lens_this_time"], + ref["end_tokens"], + ref["is_block_step"], + ref["cu_seqlens_q_output"], + ref["reasoning_status"], + ref["max_dec_len"], + ref["step_idx"], + ref["max_seq_len"], + ref["verify_window"], + ref["verify_strategy"], + ref["reject_all"], + ref["accept_all"], + ) + + +def compare_results( + paddle_inputs: Dict[str, Any], + step_output_ids_ref: np.ndarray, + step_output_len_ref: np.ndarray, + inputs: Dict[str, Any], + label: str = "unknown", +): + """Compare GPU kernel output vs reference.""" + gpu_ids = paddle_inputs["step_output_ids"].numpy() + gpu_len = paddle_inputs["step_output_len"].numpy() + np.testing.assert_array_equal( + gpu_len, + step_output_len_ref, + err_msg=f"step_output_len mismatch ({label})", + ) + + if inputs["verify_strategy"] == 0: # TOPP — Phase 2 is stochastic + real_bsz = inputs["seq_lens_this_time"].shape[0] + for bid in range(real_bsz): + ref_len = int(step_output_len_ref[bid]) + if ref_len > 1: + print(gpu_ids[bid, : ref_len - 1], 
step_output_ids_ref[bid, : ref_len - 1]) + np.testing.assert_array_equal( + gpu_ids[bid, : ref_len - 1], + step_output_ids_ref[bid, : ref_len - 1], + err_msg=f"step_output_ids (accepted) mismatch at bid={bid} ({label})", + ) + else: + np.testing.assert_array_equal( + gpu_ids, + step_output_ids_ref, + err_msg=f"step_output_ids mismatch ({label})", + ) + + +# ============================================================ +# Reference helpers +# ============================================================ + + +def topp_sampling_kernel(candidate_ids, candidate_scores, curand_value, candidate_len, topp, tid=0): + rand_top_p = curand_value * topp + sum_scores = 0.0 + for i in range(candidate_len): + sum_scores += candidate_scores[i] + if rand_top_p <= sum_scores: + return int(candidate_ids[i]) + return int(candidate_ids[0]) + + +def is_in_end(token, end_tokens, end_length): + return token in end_tokens[:end_length] + + +def is_in(candidate_list, token, length): + return token in candidate_list[:length] + + +class _VerifyContext: + """Python mirror of the CUDA VerifyContext struct for reference testing.""" + + def __init__( + self, + bid, + max_step_tokens, + end_length, + end_tokens, + max_dec_len, + step_input_ids_now, + step_output_ids_flat, + cur_step_idx, + ): + self.bid = bid + self.max_step_tokens = max_step_tokens + self.end_length = end_length + self.end_tokens = end_tokens + self.max_dec_len = max_dec_len + self.step_input_ids_now = step_input_ids_now + self.step_output_ids_flat = step_output_ids_flat + self.cur_step_idx = cur_step_idx + self.output_len_now = 0 + self.stopped = False + + def emit_token(self, pos, token): + """Emit a token to output. 
Returns True if sequence should stop.""" + self.cur_step_idx += 1 + eos = is_in_end(token, self.end_tokens, self.end_length) + max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid]) + if (eos or max_hit) and not eos: + token = int(self.end_tokens[0]) + self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token + self.output_len_now += 1 + if eos or max_hit: + self.stopped = True + return True + return False + + def emit_final_token(self, pos, token): + """Emit the Phase 2 final token. Increments output_len_now.""" + self.cur_step_idx += 1 + eos = is_in_end(token, self.end_tokens, self.end_length) + max_hit = self.cur_step_idx >= int(self.max_dec_len[self.bid]) + if (eos or max_hit) and not eos: + token = int(self.end_tokens[0]) + self.step_output_ids_flat[self.bid * self.max_step_tokens + pos] = token + self.output_len_now += 1 + + +def verify_draft_tokens_ref( + step_output_ids, + step_output_len, + step_input_ids, + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + topp, + stop_flags, + seq_lens_encoder, + seq_lens_this_time, + end_tokens, + is_block_step, + cu_seqlens_q_output, + reasoning_status, + max_dec_len, + step_idx, + max_seq_len, + verify_window, + verify_strategy, + reject_all, + accept_all, +): + """Reference implementation of verify_draft_tokens in Python.""" + real_bsz = seq_lens_this_time.shape[0] + max_step_tokens = step_input_ids.shape[1] + end_length = end_tokens.shape[0] + max_candidate_len = candidate_ids.shape[1] if candidate_ids is not None else 1 + + dev_curand_states = [random.Random(0).random() for _ in range(max_step_tokens)] + + step_output_ids_flat = step_output_ids.reshape(-1) + step_input_ids_flat = step_input_ids.reshape(-1) + candidate_ids_flat = candidate_ids.reshape(-1) if candidate_ids is not None else None + candidate_scores_flat = candidate_scores.reshape(-1) if candidate_scores is not None else None + + for bid in range(real_bsz): + start_token_id = cu_seqlens_q_output[bid] + + if 
is_block_step[bid] or stop_flags[bid]: + step_output_len[bid] = 0 + continue + + step_input_ids_now = step_input_ids_flat[bid * max_step_tokens :] + target_tokens_now = target_tokens[start_token_id:] if target_tokens is not None else None + candidate_ids_now = ( + candidate_ids_flat[start_token_id * max_candidate_len :] if candidate_ids_flat is not None else None + ) + candidate_lens_now = candidate_lens[start_token_id:] if candidate_lens is not None else None + candidate_scores_now = ( + candidate_scores_flat[start_token_id * max_candidate_len :] if candidate_scores_flat is not None else None + ) + + ctx = _VerifyContext( + bid, + max_step_tokens, + end_length, + end_tokens, + max_dec_len, + step_input_ids_now, + step_output_ids_flat, + int(step_idx[bid]), + ) + + # Phase 1: Verify + i = 0 + while i < seq_lens_this_time[bid] - 1: + if reject_all or seq_lens_encoder[bid] != 0 or reasoning_status[bid] == 1: + break + if accept_all: + if ctx.emit_token(i, step_input_ids_now[i + 1]): + break + i += 1 + continue + + accepted = False + if verify_strategy == 0: # TOPP + actual_cand_len = min(candidate_lens_now[i], max_candidate_len) + accepted = is_in( + candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len], + step_input_ids_now[i + 1], + actual_cand_len, + ) + if not accepted: + # verify_window fallback + ii = i + if ( + max_candidate_len >= 2 + and candidate_ids_now[ii * max_candidate_len + 1] == step_input_ids_now[ii + 1] + ): + j, ii = 0, ii + 1 + while j < verify_window and ii < seq_lens_this_time[bid] - 1: + if candidate_ids_now[ii * max_candidate_len] != step_input_ids_now[ii + 1]: + break + j += 1 + ii += 1 + if j >= verify_window: + for k in range(i, ii): + if ctx.emit_token(k, step_input_ids_now[k + 1]): + i = k + break + if ctx.stopped: + break + i = ii + continue + break + elif verify_strategy in (1, 2): # GREEDY / TARGET_MATCH + accepted = target_tokens_now[i] == step_input_ids_now[i + 1] + + if accepted: + if ctx.emit_token(i, 
step_input_ids_now[i + 1]): + break + else: + break + i += 1 + + # Phase 2: Sample for rejected/last position + if not ctx.stopped: + if verify_strategy == 0: + if candidate_lens_now is not None and len(candidate_lens_now) > i: + actual_cand_len = min(candidate_lens_now[i], max_candidate_len) + accept_token = topp_sampling_kernel( + candidate_ids_now[i * max_candidate_len : (i + 1) * max_candidate_len], + candidate_scores_now[i * max_candidate_len : (i + 1) * max_candidate_len], + dev_curand_states[i], + actual_cand_len, + topp[bid], + ) + else: + accept_token = int(step_input_ids_now[0]) + elif verify_strategy in (1, 2): + accept_token = ( + int(target_tokens_now[i]) + if target_tokens_now is not None and len(target_tokens_now) > i + else int(step_input_ids_now[0]) + ) + else: + accept_token = ( + int(candidate_ids_now[i * max_candidate_len]) + if candidate_ids_now is not None + else int(step_input_ids_now[0]) + ) + ctx.emit_final_token(i, accept_token) + + step_output_len[bid] = ctx.output_len_now + + return step_output_ids, step_output_len + + +# ============================================================ +# Input generation +# ============================================================ + + +def gen_verify_draft_tokens_inputs( + real_bsz: int = 32, + max_draft_tokens: int = 16, + max_seq_len: int = 256, + max_candidate_len: int = 8, + verify_window: int = 2, + end_length: int = 4, + verify_strategy: int = 1, + reject_all: bool = False, + accept_all: bool = False, + match_ratio: float = 0.0, + seed: int = 2025, +) -> Dict[str, Any]: + """Generate test inputs for verify_draft_tokens kernel. + + Args: + match_ratio: Fraction of draft token positions where target/candidates + are forced to match step_input_ids, so the acceptance path is exercised. + 0.0 = fully random (mostly rejects), 1.0 = all positions match. 
+ """ + rng = np.random.default_rng(seed) + + seq_lens_encoder = np.zeros(real_bsz, dtype=np.int32) + seq_lens_this_time = rng.integers(1, max_draft_tokens + 1, size=real_bsz, dtype=np.int32) + step_input_ids = rng.integers(0, 1000, size=(real_bsz, max_draft_tokens), dtype=np.int64) + + sum_seq = int(np.sum(seq_lens_this_time)) + + if verify_strategy in (1, 2): # GREEDY / TARGET_MATCH + target_tokens = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64) + candidate_ids = None + candidate_scores = None + candidate_lens = None + else: # TOPP + target_tokens = None + candidate_ids = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + candidate_scores = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + candidate_scores = candidate_scores / candidate_scores.sum(axis=1, keepdims=True) + candidate_lens = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) + + end_tokens = rng.integers(1, 1000, size=end_length, dtype=np.int64) + is_block_step = rng.integers(0, 2, size=real_bsz, dtype=bool) + + cu_seqlens_q_output = np.zeros(real_bsz + 1, dtype=np.int32) + for i in range(real_bsz): + cu_seqlens_q_output[i + 1] = cu_seqlens_q_output[i] + seq_lens_this_time[i] + cu_seqlens_q_output = cu_seqlens_q_output[:real_bsz].astype(np.int32) + + topp = rng.uniform(0.8, 1.0, size=real_bsz).astype(np.float32) + reasoning_status = np.zeros(real_bsz, dtype=np.int32) + step_output_ids = np.zeros((real_bsz, max_draft_tokens), dtype=np.int64) + step_output_len = np.zeros(real_bsz, dtype=np.int32) + stop_flags = np.zeros(real_bsz, dtype=bool) + + # Force match_ratio fraction of positions so acceptance path is tested + if match_ratio > 0.0: + offset = 0 + for bid in range(real_bsz): + slt = int(seq_lens_this_time[bid]) + n_match = max(1, int((slt - 1) * match_ratio)) # slt-1 verify positions + for pos in range(min(n_match, slt - 1)): + draft_token = int(step_input_ids[bid, pos + 1]) + # Ensure draft_token is not an end_token (would 
cause early stop) + while draft_token in end_tokens[:end_length]: + draft_token = (draft_token + 1) % 1000 + step_input_ids[bid, pos + 1] = draft_token + if verify_strategy in (1, 2) and target_tokens is not None: + target_tokens[offset + pos] = draft_token + elif verify_strategy == 0 and candidate_ids is not None: + candidate_ids[offset + pos, 0] = draft_token + candidate_lens[offset + pos] = max(candidate_lens[offset + pos], 1) + offset += slt + + return { + "step_output_ids": step_output_ids, + "step_output_len": step_output_len, + "step_input_ids": step_input_ids, + "target_tokens": target_tokens, + "candidate_ids": candidate_ids, + "candidate_scores": candidate_scores, + "candidate_lens": candidate_lens, + "topp": topp, + "stop_flags": stop_flags, + "seq_lens_encoder": seq_lens_encoder, + "seq_lens_this_time": seq_lens_this_time, + "end_tokens": end_tokens, + "is_block_step": is_block_step, + "cu_seqlens_q_output": cu_seqlens_q_output, + "reasoning_status": reasoning_status, + "max_dec_len": rng.integers(50, 200, size=real_bsz, dtype=np.int64), + "step_idx": rng.integers(0, 30, size=real_bsz, dtype=np.int64), + "max_seq_len": max_seq_len, + "verify_window": verify_window, + "verify_strategy": verify_strategy, + "reject_all": reject_all, + "accept_all": accept_all, + } + + +# ============================================================ +# Test configs +# ============================================================ + +TEST_CONFIGS = [ + # --- strategy coverage (random, mostly rejects) --- + { + "name": "greedy_small_batch", + "real_bsz": 1, + "max_draft_tokens": 9, + "max_seq_len": 11, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 5, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, + { + "name": "greedy_medium_batch", + "real_bsz": 33, + "max_draft_tokens": 5, + "max_seq_len": 10111, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 6, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, + { + 
"name": "topp_small_batch", + "real_bsz": 6, + "max_draft_tokens": 4, + "max_seq_len": 10001, + "max_candidate_len": 6, + "verify_window": 2, + "end_length": 7, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + }, + { + "name": "target_match_medium", + "real_bsz": 7, + "max_draft_tokens": 3, + "max_seq_len": 777, + "max_candidate_len": 7, + "verify_window": 2, + "end_length": 5, + "verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + }, + { + "name": "greedy_large_batch", + "real_bsz": 55, + "max_draft_tokens": 5, + "max_seq_len": 31, + "max_candidate_len": 9, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, + # --- partial acceptance (match_ratio forces draft tokens to match target/candidates) --- + { + "name": "greedy_half_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "match_ratio": 0.5, + }, + { + "name": "greedy_full_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "match_ratio": 1.0, + }, + { + "name": "topp_half_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 6, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "match_ratio": 0.5, + }, + { + "name": "topp_full_accept", + "real_bsz": 8, + "max_draft_tokens": 8, + "max_seq_len": 256, + "max_candidate_len": 6, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "match_ratio": 1.0, + }, + { + "name": "target_match_accept", + "real_bsz": 8, + "max_draft_tokens": 6, + "max_seq_len": 256, + "max_candidate_len": 4, + "verify_window": 2, + "end_length": 3, + 
"verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + "match_ratio": 0.7, + }, + # --- reject_all / accept_all (kernel-level flags) --- + { + "name": "reject_all", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "reject_all": True, + }, + { + "name": "accept_all", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "accept_all": True, + }, + { + "name": "reject_all_topp", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TOPP.value, + "seed": 42, + "reject_all": True, + }, + { + "name": "reject_all_target_match", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + "reject_all": True, + }, + { + "name": "accept_all_greedy", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + "accept_all": True, + }, + { + "name": "accept_all_target_match", + "real_bsz": 8, + "max_draft_tokens": 5, + "max_seq_len": 100, + "max_candidate_len": 5, + "verify_window": 2, + "end_length": 3, + "verify_strategy": VerifyStrategy.TARGET_MATCH.value, + "seed": 42, + "accept_all": True, + }, + # --- edge cases --- + { + "name": "empty_batch", + "real_bsz": 1, + "max_draft_tokens": 1, + "max_seq_len": 10, + "max_candidate_len": 2, + "verify_window": 1, + "end_length": 4, + "verify_strategy": VerifyStrategy.GREEDY.value, + "seed": 42, + }, +] + + +# 
============================================================ +# Test suite +# ============================================================ + + +class TestVerifyDraftTokens(unittest.TestCase): + + def setUp(self): + pass + # if not paddle.is_compiled_with_cuda(): + # self.skipTest("Requires CUDA") + + # ------ shared run + check helper ------ + + def _run_and_compare(self, inputs: Dict[str, Any], label: str = ""): + """Convert→run kernel→run ref→compare.""" + paddle_inputs = to_paddle_inputs(inputs) + # print("paddle_inputs: ", paddle_inputs) + run_kernel(paddle_inputs, inputs) + ids_ref, len_ref = run_ref(inputs) + compare_results(paddle_inputs, ids_ref, len_ref, inputs, label) + return paddle_inputs + + # ------ test cases ------ + + def test_verify_configs(self): + """Test all configs in TEST_CONFIGS (strategies, reject/accept, edge cases).""" + for cfg in TEST_CONFIGS: + with self.subTest(name=cfg["name"]): + test_cfg = {k: v for k, v in cfg.items() if k != "name"} + inputs = gen_verify_draft_tokens_inputs(**test_cfg) + self._run_and_compare(inputs, label=cfg["name"]) + + def test_eos_handling(self): + """Test EOS token in draft triggers early stop.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42 + ) + inputs["step_input_ids"][0, 2] = inputs["end_tokens"][0] + self._run_and_compare(inputs, label="eos_handling") + + def test_max_dec_len_truncation(self): + """Test max_dec_len causes token replacement with end_tokens[0].""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + # Set step_idx close to max_dec_len so it triggers during verification + inputs["step_idx"][:] = [48, 10, 10, 10] + inputs["max_dec_len"][:] = [50, 200, 200, 200] + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Ensure no accidental EOS in draft tokens + for bid in range(4): + for j in 
range(5): + while inputs["step_input_ids"][bid, j] in inputs["end_tokens"]: + inputs["step_input_ids"][bid, j] = (inputs["step_input_ids"][bid, j] + 1) % 1000 + self._run_and_compare(inputs, label="max_dec_len_truncation") + + def test_verify_strategy_enum(self): + self.assertEqual(VerifyStrategy.TOPP.value, 0) + self.assertEqual(VerifyStrategy.GREEDY.value, 1) + self.assertEqual(VerifyStrategy.TARGET_MATCH.value, 2) + + def test_verify_strategy_from_string(self): + self.assertEqual(VerifyStrategy.from_string("topp"), VerifyStrategy.TOPP) + self.assertEqual(VerifyStrategy.from_string("TOPP"), VerifyStrategy.TOPP) + self.assertEqual(VerifyStrategy.from_string("greedy"), VerifyStrategy.GREEDY) + self.assertEqual(VerifyStrategy.from_string("target_match"), VerifyStrategy.TARGET_MATCH) + with self.assertRaises(ValueError): + VerifyStrategy.from_string("invalid") + + def test_topp_verify_window_fallback(self): + """Test TOPP verify_window fallback: top-2 match + consecutive top-1 matches.""" + real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 8, 4, 2 + + inputs = gen_verify_draft_tokens_inputs( + real_bsz=real_bsz, + max_draft_tokens=max_draft_tokens, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=max_candidate_len, + verify_window=verify_window, + seed=42, + ) + + # Rebuild arrays for full seq_lens_this_time + new_slt = max_draft_tokens + 1 + inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32) + inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32) + + rng = np.random.default_rng(42) + sum_seq = new_slt + inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) + inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) + + # Draft tokens + draft_tokens = 
[100, 200, 300, 400, 500, 600, 700] + for i, token in enumerate(draft_tokens): + inputs["step_input_ids"][0, i + 1] = token + + # Position 0: draft NOT in candidates, but top-2 matches draft + inputs["candidate_ids"][0] = [999, 100, 998, 997] + # Positions 1,2: top-1 matches next draft tokens + inputs["candidate_ids"][1] = [200, 888, 777, 666] + inputs["candidate_ids"][2] = [300, 555, 444, 333] + inputs["candidate_lens"][:3] = 4 + inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool) + + self._run_and_compare(inputs, label="verify_window_fallback") + + def test_topp_verify_window_no_fallback(self): + """Test TOPP when verify_window fallback does NOT trigger.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=1, + max_draft_tokens=5, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=4, + verify_window=2, + seed=42, + ) + + inputs["step_input_ids"][0, 1:] = [999, 998, 997, 996] + inputs["candidate_ids"][:] = 0 + inputs["candidate_ids"][0] = [1, 2, 3, 4] + inputs["candidate_lens"][0] = 4 + inputs["seq_lens_this_time"][0] = 5 + + self._run_and_compare(inputs, label="verify_window_no_fallback") + + def test_stop_flags_skip(self): + """Test that sequences with stop_flags=True are skipped (output_len=0).""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = [True, False, True, False] + self._run_and_compare(inputs, label="stop_flags_skip") + # Double-check stopped sequences produce output_len=0 + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + gpu_len = paddle_inputs["step_output_len"].numpy() + self.assertEqual(gpu_len[0], 0, "stopped seq bid=0 should have output_len=0") + self.assertEqual(gpu_len[2], 0, "stopped seq bid=2 should have output_len=0") + + def test_prefill_skip(self): + """Test that prefill requests (seq_lens_encoder != 0) skip Phase 1, only 
output 1 token.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Set bid 0 and 2 as prefill requests + inputs["seq_lens_encoder"][0] = 10 + inputs["seq_lens_encoder"][2] = 5 + self._run_and_compare(inputs, label="prefill_skip") + + def test_reasoning_status_skip(self): + """Test that reasoning_status=1 skips Phase 1, only outputs 1 token.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Set bid 1 and 3 as reasoning mode + inputs["reasoning_status"][1] = 1 + inputs["reasoning_status"][3] = 1 + self._run_and_compare(inputs, label="reasoning_status_skip") + + def test_reject_all_and_accept_all_priority(self): + """Test that reject_all takes priority over accept_all when both are True.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, + max_draft_tokens=5, + verify_strategy=VerifyStrategy.GREEDY.value, + seed=42, + match_ratio=1.0, + reject_all=True, + accept_all=True, + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + self._run_and_compare(inputs, label="reject_all_and_accept_all") + # All sequences should produce exactly 1 token (Phase 2 only) + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + gpu_len = paddle_inputs["step_output_len"].numpy() + for bid in range(4): + self.assertEqual(gpu_len[bid], 1, f"reject_all should produce exactly 1 token at bid={bid}") + + def test_mixed_batch_heterogeneous(self): + """Test a batch with mixed states: normal, stopped, prefill, reasoning, block_step.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=6, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=0.8 + ) + # bid 
0: normal decode + inputs["is_block_step"][0] = False + inputs["stop_flags"][0] = False + inputs["seq_lens_encoder"][0] = 0 + inputs["reasoning_status"][0] = 0 + # bid 1: stopped + inputs["is_block_step"][1] = False + inputs["stop_flags"][1] = True + inputs["seq_lens_encoder"][1] = 0 + inputs["reasoning_status"][1] = 0 + # bid 2: prefill + inputs["is_block_step"][2] = False + inputs["stop_flags"][2] = False + inputs["seq_lens_encoder"][2] = 8 + inputs["reasoning_status"][2] = 0 + # bid 3: reasoning mode + inputs["is_block_step"][3] = False + inputs["stop_flags"][3] = False + inputs["seq_lens_encoder"][3] = 0 + inputs["reasoning_status"][3] = 1 + # bid 4: block step + inputs["is_block_step"][4] = True + inputs["stop_flags"][4] = False + inputs["seq_lens_encoder"][4] = 0 + inputs["reasoning_status"][4] = 0 + # bid 5: normal decode + inputs["is_block_step"][5] = False + inputs["stop_flags"][5] = False + inputs["seq_lens_encoder"][5] = 0 + inputs["reasoning_status"][5] = 0 + self._run_and_compare(inputs, label="mixed_batch_heterogeneous") + + def test_single_token_sequence(self): + """Test seq_lens_this_time=1: Phase 1 is skipped entirely, only Phase 2 outputs 1 token.""" + for strategy in [VerifyStrategy.GREEDY.value, VerifyStrategy.TOPP.value, VerifyStrategy.TARGET_MATCH.value]: + with self.subTest(strategy=strategy): + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=8, verify_strategy=strategy, seed=42 + ) + inputs["seq_lens_this_time"][:] = 1 + # Recompute cu_seqlens_q_output for all-1 seq_lens + inputs["cu_seqlens_q_output"] = np.array([0, 1, 2, 3], dtype=np.int32) + # Regenerate target/candidate arrays for new sum_seq=4 + sum_seq = 4 + rng = np.random.default_rng(42) + if strategy in (1, 2): + inputs["target_tokens"] = rng.integers(0, 1000, size=(sum_seq,), dtype=np.int64) + else: + max_candidate_len = 8 + inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + inputs["candidate_scores"] = 
rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) + inputs["candidate_lens"] = rng.integers(1, max_candidate_len + 1, size=sum_seq, dtype=np.int32) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + self._run_and_compare(inputs, label=f"single_token_strategy_{strategy}") + + def test_max_dec_len_exact_boundary(self): + """Test step_idx == max_dec_len - 1: first emit triggers max_len_hit immediately.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=6, verify_strategy=VerifyStrategy.GREEDY.value, seed=42, match_ratio=1.0 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Set step_idx = max_dec_len - 1, so first emit_token increments past max_dec_len + inputs["max_dec_len"][:] = 50 + inputs["step_idx"][:] = 49 + # Ensure no accidental EOS in draft tokens + for bid in range(4): + for j in range(6): + while inputs["step_input_ids"][bid, j] in inputs["end_tokens"]: + inputs["step_input_ids"][bid, j] = (inputs["step_input_ids"][bid, j] + 1) % 1000 + self._run_and_compare(inputs, label="max_dec_len_exact_boundary") + # All sequences should produce exactly 1 token (first emit triggers stop) + paddle_inputs = to_paddle_inputs(inputs) + run_kernel(paddle_inputs, inputs) + gpu_len = paddle_inputs["step_output_len"].numpy() + for bid in range(4): + self.assertEqual(gpu_len[bid], 1, f"max_dec_len boundary should produce 1 token at bid={bid}") + + def test_eos_during_verify_window_bulk_accept(self): + """Test EOS token in the middle of verify_window bulk-accept range stops correctly.""" + real_bsz, max_draft_tokens, max_candidate_len, verify_window = 1, 10, 4, 2 + inputs = gen_verify_draft_tokens_inputs( + real_bsz=real_bsz, + max_draft_tokens=max_draft_tokens, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=max_candidate_len, + verify_window=verify_window, + seed=42, + ) + + new_slt = 
max_draft_tokens + inputs["seq_lens_this_time"] = np.array([new_slt], dtype=np.int32) + inputs["cu_seqlens_q_output"] = np.array([0], dtype=np.int32) + + rng = np.random.default_rng(42) + sum_seq = new_slt + inputs["candidate_ids"] = rng.integers(0, 1000, size=(sum_seq, max_candidate_len), dtype=np.int64) + inputs["candidate_scores"] = rng.random(size=(sum_seq, max_candidate_len)).astype(np.float32) + inputs["candidate_scores"] /= inputs["candidate_scores"].sum(axis=1, keepdims=True) + inputs["candidate_lens"] = np.full(sum_seq, max_candidate_len, dtype=np.int32) + inputs["is_block_step"] = np.zeros(real_bsz, dtype=bool) + inputs["stop_flags"] = np.zeros(real_bsz, dtype=bool) + inputs["max_dec_len"][:] = 200 + + eos_token = int(inputs["end_tokens"][0]) + # Draft tokens: 100, 200, EOS, 400, 500, ... + draft_tokens = [100, 200, eos_token, 400, 500, 600, 700, 800, 900] + for i, token in enumerate(draft_tokens): + inputs["step_input_ids"][0, i + 1] = token + + # Position 0: draft NOT in top-1, but top-2 matches draft -> verify_window triggers + inputs["candidate_ids"][0] = [999, 100, 998, 997] + # Position 1: top-1 matches next draft + inputs["candidate_ids"][1] = [200, 888, 777, 666] + # Position 2: top-1 matches next draft (which is EOS) + inputs["candidate_ids"][2] = [eos_token, 555, 444, 333] + # Position 3 onwards: top-1 matches (shouldn't be reached due to EOS) + inputs["candidate_ids"][3] = [400, 222, 111, 100] + + self._run_and_compare(inputs, label="eos_during_verify_window") + + def test_topp_max_candidate_len_1(self): + """Test TOPP with max_candidate_len=1: verify_window fallback cannot trigger.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, + max_draft_tokens=6, + verify_strategy=VerifyStrategy.TOPP.value, + max_candidate_len=1, + verify_window=2, + seed=42, + match_ratio=0.5, + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + self._run_and_compare(inputs, label="topp_max_candidate_len_1") + + def 
test_phase2_eos_token(self): + """Test Phase 2 target token is an EOS token.""" + inputs = gen_verify_draft_tokens_inputs( + real_bsz=4, max_draft_tokens=5, verify_strategy=VerifyStrategy.GREEDY.value, seed=42 + ) + inputs["is_block_step"][:] = False + inputs["stop_flags"][:] = False + # Make all draft tokens NOT match target (all reject at position 0) + inputs["step_input_ids"][:, 1:] = 999 + if inputs["target_tokens"] is not None: + inputs["target_tokens"][:] = 888 + # Now set the Phase 2 token (target_tokens at position 0 for each bid) to EOS + eos_token = int(inputs["end_tokens"][0]) + offset = 0 + for bid in range(4): + inputs["target_tokens"][offset] = eos_token + offset += int(inputs["seq_lens_this_time"][bid]) + self._run_and_compare(inputs, label="phase2_eos_token") + + +if __name__ == "__main__": + unittest.main() diff --git a/fastdeploy/model_executor/forward_meta.py b/fastdeploy/model_executor/forward_meta.py index d6a448a2693..35e7e9c2dec 100644 --- a/fastdeploy/model_executor/forward_meta.py +++ b/fastdeploy/model_executor/forward_meta.py @@ -273,6 +273,7 @@ class XPUForwardMeta(ForwardMeta): hidden_states: Optional[paddle.Tensor] = None is_draft: bool = False + is_speculative: bool = False # max bs max_num_seqs: int = 0 diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index 08a33c11096..1afb9493897 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -1044,17 +1044,129 @@ def forward_cuda( sampler_output.cu_batch_token_offset = cu_batch_token_offset.cpu() return sampler_output - def forward_xpu( + def _normal_sample_xpu( + self, + logits: paddle.Tensor, + probs: paddle.Tensor, + sampling_metadata: SamplingMetadata, + share_inputs: List[paddle.Tensor], + ) -> SamplerOutput: + """Normal sampling for NAIVE mode on XPU.""" + top_p, top_k, topp_seed = padding_sampling_params( + sampling_metadata.top_p, + 
sampling_metadata.top_k, + sampling_metadata.seed, + paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), + paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), + ) + _, next_tokens = top_k_top_p_sampling( + probs, + top_p=top_p, + top_k=top_k, + top_k_list=sampling_metadata.top_k_list, + topp_seed=topp_seed, + ) + real_bsz = share_inputs["seq_lens_this_time"].shape[0] + running_mask = (paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]) > 0).cast("int32") + share_inputs["accept_tokens"][:real_bsz, 0] = next_tokens.squeeze(-1) + share_inputs["accept_num"][:real_bsz] = running_mask + return SamplerOutput( + sampled_token_ids=share_inputs["accept_tokens"], + logprobs_tensors=None, + token_num_per_batch=share_inputs["accept_num"], + logits=logits, + ) + + def _verify_and_sample_xpu( self, logits: paddle.Tensor, + probs: paddle.Tensor, sampling_metadata: SamplingMetadata, max_model_len: int, share_inputs: List[paddle.Tensor], accept_all_drafts: bool = False, reject_all_drafts: bool = False, - ) -> paddle.Tensor: - from fastdeploy.model_executor.ops.xpu import speculate_verify, top_p_candidates + ) -> SamplerOutput: + """Verify draft tokens (MTP/Ngram mode) on XPU using verify_draft_tokens.""" + from fastdeploy.model_executor.ops.xpu import ( + top_p_candidates, + verify_draft_tokens, + ) + + target_tokens = None + candidate_ids, candidate_scores, candidate_lens = None, None, None + if self.verify_strategy == VerifyStrategy.TARGET_MATCH: + top_p, top_k, topp_seed = padding_sampling_params( + sampling_metadata.top_p, + sampling_metadata.top_k, + sampling_metadata.seed, + paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), + paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), + ) + _, target_tokens = top_k_top_p_sampling( + probs, + top_p=top_p, + top_k=top_k, + top_k_list=sampling_metadata.top_k_list, + topp_seed=topp_seed, + ) + elif self.verify_strategy == VerifyStrategy.GREEDY: + target_tokens = paddle.argmax(probs, 
axis=-1) + elif self.verify_strategy == VerifyStrategy.TOPP: + candidate_scores, candidate_ids, candidate_lens = top_p_candidates( + probs, + sampling_metadata.top_p, + share_inputs["batch_id_per_token_output"], + self.speculative_max_candidate_len, + max_model_len, + ) + else: + raise ValueError(f"Unknown verify strategy: {self.verify_strategy}") + + final_accept_all = self.config_accept_all or accept_all_drafts + final_reject_all = self.config_reject_all or reject_all_drafts or self.speculative_benchmark_mode + + verify_draft_tokens( + share_inputs["accept_tokens"], + share_inputs["accept_num"], + share_inputs["draft_tokens"], + target_tokens, + candidate_ids, + candidate_scores, + candidate_lens, + sampling_metadata.top_p, + share_inputs["stop_flags"], + share_inputs["seq_lens_encoder"], + share_inputs["seq_lens_this_time"], + sampling_metadata.eos_token_ids, + share_inputs["is_block_step"], + share_inputs["cu_seqlens_q_output"], + share_inputs["reasoning_status"], + share_inputs["max_dec_len"], + share_inputs["step_idx"], + max_model_len, + self.speculative_verify_window, + self.verify_strategy.value, + final_reject_all, + final_accept_all, + ) + return SamplerOutput( + sampled_token_ids=share_inputs["accept_tokens"], + logprobs_tensors=None, + token_num_per_batch=share_inputs["accept_num"], + logits=logits, + ) + + def forward_xpu( + self, + logits: paddle.Tensor, + sampling_metadata: SamplingMetadata, + max_model_len: int, + share_inputs: List[paddle.Tensor], + accept_all_drafts: bool = False, + reject_all_drafts: bool = False, + ) -> SamplerOutput: logits = apply_speculative_penalty_multi_scores( sampling_metadata.token_ids_all, sampling_metadata.prompt_lens, @@ -1077,61 +1189,19 @@ def forward_xpu( probs = F.softmax(logits) - top_p, top_k, topp_seed = padding_sampling_params( - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.seed, - paddle.reshape(share_inputs["seq_lens_this_time"], shape=[-1]), - 
paddle.reshape(share_inputs["seq_lens_encoder"], shape=[-1]), - ) - _, sampled_token_ids = top_k_top_p_sampling( - probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed - ) - - verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( - probs, - sampling_metadata.top_p, - share_inputs["batch_id_per_token_output"], - self.speculative_max_candidate_len, - max_model_len, - ) - - speculate_verify( - sampled_token_ids, - share_inputs["accept_tokens"], - share_inputs["accept_num"], - share_inputs["step_idx"], - share_inputs["stop_flags"], - share_inputs["seq_lens_encoder"], - share_inputs["seq_lens_decoder"], - share_inputs[ - "draft_tokens" - ], # Both input and output, need to write the last 1 token accepted to position 0. - share_inputs["seq_lens_this_time"], - verify_tokens, - verify_scores, - share_inputs["max_dec_len"], - sampling_metadata.eos_token_ids, - share_inputs["is_block_step"], - share_inputs["cu_seqlens_q_output"], - actual_candidate_len, - share_inputs["actual_draft_token_num"], - sampling_metadata.top_p, - max_model_len, - self.speculative_verify_window, - True, # enable_topp - (self.speculative_benchmark_mode or reject_all_drafts), - accept_all_drafts, - ) - # TODO(chenhuan09): support return logprobs - token_ids = share_inputs["accept_tokens"] - sampler_output = SamplerOutput( - sampled_token_ids=token_ids, - logprobs_tensors=None, - token_num_per_batch=share_inputs["accept_num"], - cu_batch_token_offset=None, - ) - return sampler_output + is_naive = self.spec_method is None or self.spec_method == SpecMethod.NAIVE + if is_naive: + return self._normal_sample_xpu(logits, probs, sampling_metadata, share_inputs) + else: + return self._verify_and_sample_xpu( + logits, + probs, + sampling_metadata, + max_model_len, + share_inputs, + accept_all_drafts, + reject_all_drafts, + ) class MTPSampler(nn.Layer): diff --git a/fastdeploy/model_executor/xpu_pre_and_post_process.py 
b/fastdeploy/model_executor/xpu_pre_and_post_process.py index 9e32ea34876..e55bb904fb7 100644 --- a/fastdeploy/model_executor/xpu_pre_and_post_process.py +++ b/fastdeploy/model_executor/xpu_pre_and_post_process.py @@ -43,12 +43,11 @@ speculate_pre_process, speculate_save_output, speculate_set_stop_value_multi_seqs, - speculate_set_value_by_flags_and_idx, speculate_step_paddle, speculate_step_reschedule, speculate_step_system_cache, - speculate_update, step_paddle, + unified_update_model_status, update_inputs, update_inputs_v1, ) @@ -150,6 +149,7 @@ def xpu_pre_process( block_tables=share_inputs["block_tables"], caches=share_inputs["caches"], max_num_seqs=share_inputs["seq_lens_this_time"].shape[0], + is_speculative=use_speculate_method, ) ( @@ -220,11 +220,6 @@ def xpu_process_output( ) -> paddle.Tensor: """ """ - if isinstance(share_inputs, dict): - output_padding_offset = share_inputs.get("output_padding_offset", None) - else: - output_padding_offset = getattr(share_inputs, "output_padding_offset", None) - hiddden_states = gather_next_token( forward_output, xpu_forward_meta.encoder_seq_lod, @@ -236,7 +231,7 @@ def xpu_process_output( xpu_forward_meta.encoder_batch_map_cpu, xpu_forward_meta.decoder_batch_map_cpu, xpu_forward_meta.len_info_cpu, - output_padding_offset, # output_padding_offset + xpu_forward_meta.is_speculative, xpu_forward_meta.max_num_seqs, ) return hiddden_states @@ -387,6 +382,8 @@ def xpu_post_process_specualate( share_inputs: Dict[str, paddle.Tensor], save_each_rank: bool = False, skip_save_output: bool = False, + is_naive_mode: bool = False, + prefill_one_step_stop: bool = False, ): """""" @@ -403,7 +400,7 @@ def xpu_post_process_specualate( model_output.min_tokens, ) - speculate_update( + unified_update_model_status( model_output.seq_lens_encoder, model_output.seq_lens_decoder, model_output.not_need_stop, @@ -415,6 +412,13 @@ def xpu_post_process_specualate( model_output.seq_lens_this_time, model_output.is_block_step, 
model_output.mask_rollback, + model_output.pre_ids, + model_output.prompt_lens, + model_output.step_idx, + model_output.eos_token_id, + model_output.max_dec_len, + is_naive_mode, + prefill_one_step_stop, ) if not skip_save_output: if sampler_output.logprobs_tensors is None: @@ -435,18 +439,6 @@ def xpu_post_process_specualate( speculate_clear_accept_nums(model_output.accept_num, model_output.seq_lens_decoder) - # Update pre_ids through accept tokens - speculate_set_value_by_flags_and_idx( - model_output.pre_ids, - model_output.accept_tokens, - model_output.accept_num, - model_output.stop_flags, - model_output.seq_lens_this_time, - model_output.seq_lens_encoder, - model_output.seq_lens_decoder, - model_output.step_idx, - ) - def step_xpu( share_inputs: Dict[str, paddle.Tensor], diff --git a/fastdeploy/worker/input_batch.py b/fastdeploy/worker/input_batch.py index 363dfb63097..7961920dd98 100644 --- a/fastdeploy/worker/input_batch.py +++ b/fastdeploy/worker/input_batch.py @@ -296,8 +296,8 @@ def init_share_inputs(self): dtype="int32", ) else: - self.output_cum_offsets = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.output_padding_offset = paddle.full( + self.cu_seqlens_q_output = paddle.full(shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32") + self.batch_id_per_token_output = paddle.full( shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32", @@ -439,7 +439,7 @@ def swap_data(tensor, idx1, idx2): if current_platform.is_cuda(): swap_data(self.cu_seqlens_q_output, i1, i2) else: - swap_data(self.output_cum_offsets, i1, i2) + swap_data(self.cu_seqlens_q_output, i1, i2) swap_data(self.step_draft_tokens, i1, i2) swap_data(self.step_seq_lens_this_time, i1, i2) swap_data(self.draft_logits, i1, i2) @@ -630,8 +630,8 @@ def reset_share_inputs(self): fill_paddle_tensor(self, "accept_num", 0) fill_paddle_tensor(self, "draft_tokens", -1) fill_paddle_tensor(self, "actual_draft_token_num", max_draft_token_num) - 
fill_paddle_tensor(self, "output_cum_offsets", 0) - fill_paddle_tensor(self, "output_padding_offset", 0) + fill_paddle_tensor(self, "cu_seqlens_q_output", 0) + fill_paddle_tensor(self, "batch_id_per_token_output", 0) fill_paddle_tensor(self, "step_draft_tokens", 0) fill_paddle_tensor(self, "step_seq_lens_this_time", 0) fill_paddle_tensor(self, "draft_logits", -1) @@ -741,8 +741,8 @@ def init_share_inputs(self): self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) self.token_ids_all = None else: - self.output_cum_offsets = paddle.clone(self.target_model_input_batch["output_cum_offsets"]) - self.output_padding_offset = paddle.clone(self.target_model_input_batch["output_padding_offset"]) + self.cu_seqlens_q_output = paddle.clone(self.target_model_input_batch["cu_seqlens_q_output"]) + self.batch_id_per_token_output = paddle.clone(self.target_model_input_batch["batch_id_per_token_output"]) self.pre_ids = paddle.clone(self.target_model_input_batch["pre_ids"]) self.ids_remove_padding = paddle.clone(self.target_model_input_batch["ids_remove_padding"]) self.batch_id_per_token = paddle.clone(self.target_model_input_batch["batch_id_per_token"]) diff --git a/fastdeploy/worker/xpu_model_runner.py b/fastdeploy/worker/xpu_model_runner.py index 1446257d3ae..7a3cc3c0094 100644 --- a/fastdeploy/worker/xpu_model_runner.py +++ b/fastdeploy/worker/xpu_model_runner.py @@ -135,8 +135,8 @@ def __init__( self.encoder_cache = None self.device_id = device_id - self.speculative_method = self.fd_config.speculative_config.method - self.speculative_decoding = self.speculative_method is not None + self.spec_method = self.fd_config.speculative_config.method + self.speculative_decoding = self.spec_method is not None # used by SamplingMetadata self.enable_logprob = fd_config.model_config.enable_logprob # fd_config.model_config.enable_logprob @@ -728,7 +728,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int): if has_prefill_task or has_decode_task: 
self.share_inputs["not_need_stop"][0] = True - if self.speculative_method == SpecMethod.MTP: + if self.spec_method == SpecMethod.MTP: self.proposer.insert_tasks_v1(req_dicts, num_running_requests) def insert_prefill_inputs(self, req_dicts: List[Request], num_running_requests: int): @@ -877,7 +877,7 @@ def get_attr_from_request(request, attr, default_value=None): self.share_inputs["not_need_stop"][0] = True - if self.speculative_method == SpecMethod.MTP: + if self.spec_method == SpecMethod.MTP: self.share_inputs["temp_scaled_logprobs"][idx : idx + 1] = get_attr_from_request( request, "temp_scaled_logprobs", False ) @@ -1069,12 +1069,18 @@ def _init_share_inputs(self, max_num_seqs: int): fill_value=max_draft_token_num, dtype="int32", ) - self.share_inputs["output_cum_offsets"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") - self.share_inputs["output_padding_offset"] = paddle.full( + self.share_inputs["cu_seqlens_q_output"] = paddle.full( + shape=[max_num_seqs + 1, 1], fill_value=0, dtype="int32" + ) + self.share_inputs["batch_id_per_token_output"] = paddle.full( shape=[max_num_seqs * (max_draft_token_num + 1)], fill_value=0, dtype="int32", ) + # reasoning_status: per-sequence reasoning phase indicator + # 0=thinking, 1=emitting boundary, 2=response, 3=end + # verify_draft_tokens forcibly rejects all draft tokens when reasoning_status == 1 + self.share_inputs["reasoning_status"] = paddle.full(shape=[max_num_seqs, 1], fill_value=0, dtype="int32") # For V1_KVCACHE_SCHEDULER self.share_inputs["step_draft_tokens"] = paddle.full( shape=[max_num_seqs, max_draft_token_num + 1], @@ -1437,7 +1443,7 @@ def _dummy_run( block_num=block_num, ) - if self.speculative_method == SpecMethod.MTP: + if self.spec_method == SpecMethod.MTP: self.proposer.dummy_prefill_inputs( num_tokens=num_tokens, batch_size=batch_size, @@ -1454,19 +1460,16 @@ def _init_speculative_proposer(self): """ Init speculative proposer """ - if self.speculative_method == SpecMethod.NGRAM: - # xpu not support 
ngram proposer now - self.proposer = None - elif self.speculative_method == SpecMethod.MTP: - self.proposer = self.speculative_method.create_proposer( - self.fd_config, - main_model=self.get_model(), - local_rank=self.local_rank, - device_id=self.device_id, - share_inputs=self.share_inputs, - ) - else: + if self.spec_method is None: self.proposer = None + return + self.proposer = self.spec_method.create_proposer( + self.fd_config, + main_model=self.get_model(), + local_rank=self.local_rank, + device_id=self.device_id, + share_inputs=self.share_inputs, + ) def _set_debug_level( self, debug_level: int = 0x1, model_forward_batch: Optional[List[Request]] = None, is_dummy_run: bool = False @@ -1641,6 +1644,8 @@ class at the server level, which is too granular for ModelRunner. self.share_inputs, self.parallel_config.data_parallel_size > 1, skip_save_output, + is_naive_mode=(self.speculative_decoding and self.proposer is None), + prefill_one_step_stop=self.parallel_config.prefill_one_step_stop, ) else: xpu_post_process_normal( @@ -1656,8 +1661,11 @@ class at the server level, which is too granular for ModelRunner. ) # 6. Draft model propose - if self.speculative_method == SpecMethod.MTP: - self.proposer.run(full_hidden_states=model_output) + if self.speculative_decoding and self.proposer is not None: + if self.spec_method == SpecMethod.MTP: + self.proposer.run(full_hidden_states=model_output) + else: + self.proposer.run(share_inputs=self.share_inputs) # 7. 
Updata 'infer_seed' and step_paddle() self.share_inputs["infer_seed"].add_(self.infer_seed_increment) @@ -1709,7 +1717,7 @@ def profile_run(self) -> None: """Execute a forward pass with dummy inputs to profile the memory usage of the model""" self.num_gpu_blocks = self.cache_config.total_block_num - if self.speculative_method == SpecMethod.MTP: + if self.spec_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks, profile=True) self.initialize_kv_cache(profile=True) @@ -1731,7 +1739,7 @@ def update_share_input_block_num(self, num_gpu_blocks: int) -> None: self.num_gpu_blocks = num_gpu_blocks # Reset block table and kv cache with global block num - if self.speculative_method == SpecMethod.MTP: + if self.spec_method == SpecMethod.MTP: self.proposer.initialize_kv_cache(main_model_num_blocks=self.num_gpu_blocks) self.initialize_kv_cache()