feat: one-sided target probability acceptance for MTP drafts increases acceptance rate and throughput compared to argmax alone#8
Conversation
…s acceptance rate and throughput compared to argmax alone MTP drafters use greedy argmax internally — they do not expose a full logit distribution, by design, for speed. This change adds a further tok/s improvement by allowing users to tune the acceptance threshold, achieving ~20% throughput gains by accepting more draft tokens while retaining the ability to manually verify the threshold at which semantic breakdown occurs for their specific model/task combination. When the drafter and target model disagree on a token, rather than immediately rejecting (standard argmax behaviour), --draft-p-accept triggers a one-sided softmax check over the target model's logits for the draft token. If the target assigns p >= draft-p-accept to that token, it is accepted in place of the target's own argmax prediction and decoding continues. No drafter logits are required, keeping the drafter inference path unchanged and preserving the speed advantage of argmax-only drafting. This is intentionally lighter than the full ratio test in the MTP paper. Changes: - common/sampling.cpp: add p_accept parameter to sample_and_accept_n; on drafter/target disagreement compute softmax over target logits and accept draft token if p_target(draft_token) >= p_accept - common/sampling.h: update both overloads of sample_and_accept_n signature - common/arg.cpp: register --draft-p-accept CLI argument - common/common.h: add p_accept field to common_params_speculative struct - tools/server/server-context.cpp: wire p_accept into speculative config Usage: --draft-p-accept 0.005 # accept draft token if p_target >= 0.005 --draft-p-accept 0.0 # standard argmax-only behaviour (default)
|
Review: sampler state vs accepted token In
Suggestion: decide the chosen token first, then call Minor: full-vocabulary softmax on each mismatch is O(n_vocab); worth noting for large vocabs. Otherwise the feature direction looks useful. |
Fixes sampler state bug identified by Ooooze - previously common_sampler_accept was called with target id before p_accept check, leaving grammar FSM and gsmpl->prev tracking wrong token when draft token was substituted.
|
Thanks for catching my error! — yes we need to defer common_sampler_accept until after the p_accept resolves, otherwise stale tokens are passed to the grammar FSM. I've pushed the fix with your corrections to the PR. I have several further enhancements to this feature but feel it's more important to get the throughput gains to users quickly and avoid scope creep in this PR. Will follow up in subsequent PRs. |
Overview
MTP drafters use greedy argmax internally — they do not expose a full logit distribution, by design, for speed. This change adds further tok/s improvements by allowing users to tune the acceptance threshold, achieving ~20% throughput gains by accepting more draft tokens, The user can manually verify the threshold at which semantic breakdown occurs for their specific model/task combination.
When the drafter and target model disagree on a token, rather than immediately rejecting (standard argmax behaviour), --draft-p-accept triggers a one-sided softmax check over the target model's logits for the draft token. If the target assigns p >= draft-p-accept to that token, it is accepted in place of the target's own argmax prediction and decoding continues.
No drafter logits are required, keeping the drafter inference path unchanged and preserving the speed advantage of argmax-only drafting. This is intentionally lighter than the full ratio test in the MTP paper.
Changes:
Usage:
--draft-p-accept 0.005 # accept draft token if p_target >= 0.005
--draft-p-accept 0.0 # standard argmax-only behaviour (default)
Requirements
best test 15.5 t/s 300,000 token context