[Cherry-Pick][BugFix] fix MTP bugs in TP and overlap(#7172) #7192

Changes from all commits
```diff
@@ -53,7 +53,9 @@ void SpeculateSaveOutMmsgTopK(const paddle::Tensor& sampled_token_ids,
                               int message_flag,  // Target: 3, Draft: 4
                               int64_t rank_id,
                               bool save_each_rank) {
-    if (!save_each_rank && rank_id > 0) {
+    // NOTE(yaohuicong): Skip non-zero TP ranks — they share identical sampling
+    // outputs, so only rank 0 needs to send results to the message queue.
+    if (rank_id > 0) {
         return;
     }
```

> Same as above.
```diff
@@ -345,9 +345,7 @@ def _predict_next_launch_token_num(self) -> int:
     is_block_step_cpu = self.share_inputs["is_block_step_cpu"].numpy()
     next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item()
     token_num_one_step = (self.speculative_config.num_speculative_tokens + 1) if self.speculative_decoding else 1
-    next_launch_token_num = (
-        seq_lens_this_time_cpu.sum().item() + is_block_step_cpu.sum().item() * token_num_one_step
-    )
+    next_launch_token_num = next_real_bsz * token_num_one_step

     return next_launch_token_num, next_real_bsz


 def only_prefill(self):
```

> 🟡 Suggestion: This simplification looks reasonable, since in the speculative-decoding scenario each sequence should process a fixed number of tokens per step (`num_speculative_tokens + 1`). Still, consider adding a comment explaining the issue with the original calculation, for example:
>
> ```python
> # In MTP (Multi-Token Prediction) mode, each sequence processes a fixed number of
> # tokens per step (num_speculative_tokens + 1), so we can simplify the calculation
> # from seq_lens.sum() + is_block_step.sum() * token_num_one_step to
> # next_real_bsz * token_num_one_step.
> ```
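The equivalence the suggestion relies on can be sanity-checked numerically. Below is a minimal sketch, assuming a hypothetical decode-step state in which every running sequence contributes exactly `token_num_one_step` tokens; the array values and `num_speculative_tokens` are illustrative, not taken from the PR:

```python
import numpy as np

# Assumed config for illustration: one speculative token per step.
num_speculative_tokens = 1
token_num_one_step = num_speculative_tokens + 1

# Hypothetical per-sequence state: two running sequences, one in block-step.
seq_lens_this_time_cpu = np.array([token_num_one_step, token_num_one_step, 0])
is_block_step_cpu = np.array([0, 0, 1])

# Batch size counts both running and block-stepped sequences, as in the diff.
next_real_bsz = (seq_lens_this_time_cpu > 0).sum().item() + (is_block_step_cpu > 0).sum().item()

# Original formula vs. the simplified one from the change.
old = seq_lens_this_time_cpu.sum().item() + is_block_step_cpu.sum().item() * token_num_one_step
new = next_real_bsz * token_num_one_step
assert old == new == 6
```

The two formulas agree only while every running sequence's `seq_lens_this_time` equals `token_num_one_step`, which is exactly the fixed-tokens-per-step property the comment says holds for MTP decode.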
> ❓ Question: The semantics of the `save_each_rank` parameter have been removed (the guard changed from `if (!save_each_rank && rank_id > 0)` to `if (rank_id > 0)`), which means only rank 0 sends results to the message queue regardless of the value of `save_each_rank`. Potential impact: if the speculate path supports EP (Expert Parallelism) mode, then in mixed EP + TP mode different EP ranks produce different outputs, and removing this check would drop the outputs of non-zero ranks.
>
> Suggestions:
> - Keep the `save_each_rank` check.
> - Additionally, the XPU version of `speculate_save_output.cc` (`custom_ops/xpu_ops/src/ops/mtp/speculate_save_output.cc`) still retains the original logic; consider whether it needs the same change for cross-hardware consistency.
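The behavioral difference this question raises can be made concrete with a small model of the two guards. A hedged sketch follows; the function names are hypothetical, and only the boolean conditions mirror the C++ diff:

```python
def should_save_old(save_each_rank: bool, rank_id: int) -> bool:
    # Original guard: early-return only when save_each_rank is False
    # on a non-zero rank.
    return not (not save_each_rank and rank_id > 0)


def should_save_new(save_each_rank: bool, rank_id: int) -> bool:
    # New guard: save_each_rank is ignored; only rank 0 ever saves.
    return not (rank_id > 0)


# The guards agree everywhere except save_each_rank=True on a rank > 0,
# which is exactly the EP-style case the review comment is worried about.
assert should_save_old(True, 1) and not should_save_new(True, 1)
assert should_save_old(False, 1) == should_save_new(False, 1) == False
assert should_save_old(True, 0) == should_save_new(True, 0) == True
```

Under pure TP, where all ranks sample identical outputs, the divergent case never matters; the question is whether any caller passes `save_each_rank=True` expecting per-rank saves.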