llama + spec: MTP Support #22673
Conversation
|
Nice, I think this is a better fresh start than my WIP #18886 (which I still haven't found the time to continue). There were some other attempts to add MTP support, but they all heavily rely on host <--> device data copies. I assume you tried to address this, right? (Maybe there was a discussion somewhere that I wasn't aware of.)
ngxson left a comment

(not a review, but opening some discussions)
```cpp
// number of recurrent-state snapshots per seq for rollback; tensors are widened to (1 + n_rs_seq) groups
uint32_t n_rs_seq = 0;
```
Not 100% sure, but maybe the naming with `_seq` is a bit confusing (or I'm misunderstanding this).
I imagine that we want to keep a ring-buffer style of recurrent state(s), similar to SWA in the KV cache, right? If that's the case, it's probably better to call it `n_rs_window`.
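To illustrate the ring-buffer reading (names like `n_rs_window` and `rs_snapshot` are placeholders for this sketch, not code from the PR):

```cpp
// Sketch only: keep the last n_rs_window recurrent-state snapshots per sequence
// in a ring buffer, so a rejected draft can restore any of the recent states.
#include <cstdint>
#include <utility>
#include <vector>

struct rs_snapshot {
    std::vector<uint8_t> data; // serialized recurrent state for one sequence
    int32_t pos = -1;          // token position this snapshot corresponds to
};

struct rs_ring {
    uint32_t n_rs_window = 0;       // how many snapshots are kept per sequence
    uint32_t head        = 0;       // next slot to overwrite
    std::vector<rs_snapshot> slots; // fixed-size backing storage

    explicit rs_ring(uint32_t window) : n_rs_window(window), slots(window) {}

    // overwrite the oldest slot with the newest snapshot
    void push(rs_snapshot s) {
        if (n_rs_window == 0) {
            return;
        }
        slots[head] = std::move(s);
        head = (head + 1) % n_rs_window;
    }
};
```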
Yes, this is partly from the review comment here: #22400 (comment)
|
```cpp
for (int il = 0; il < n_layer; ++il) {
// MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass.
const int n_transformer_layers = n_layer - (int)hparams.nextn_predict_layers;
```
Nit, but maybe call it `n_main_layers`, as technically the NextN layer is also a transformer layer.
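In other words, something along these lines (a sketch of the suggested rename, not the PR's actual code):

```cpp
// the NextN/MTP blocks are still transformer layers, they are just not part of
// the main decode pass, so count the "main" layers explicitly
const int n_main_layers = n_layer - (int) hparams.nextn_predict_layers;
```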
```cpp
//TODO: generalize if this is ok, we should load <arch_name>_mtp arch?
if (params_base.speculative.type == COMMON_SPECULATIVE_TYPE_MTP) {
    SRV_INF("loading MTP head from '%s' (override_arch=qwen35_mtp)\n",
            params_base.model.path.c_str());

    auto mparams_mtp = common_model_params_to_llama(params_base);
    mparams_mtp.override_arch = "qwen35_mtp";

    model_mtp.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp));
    if (model_mtp == nullptr) {
        SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str());
        return false;
    }
```
if you look at #18886, the better way is to move llama_graph_type to the public API, then load the context with the appropriate graph type
Yes, that seems like the correct way to do this if we want to support MTP in a generic way.
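To make that concrete, a rough sketch of the shape it could take (none of these declarations exist in llama.h today; the names are assumptions):

```cpp
// Hypothetical sketch: expose the graph type in the public API so that a single
// model load can back two contexts over the same weights, one building the
// regular decode graph and one building only the NextN/MTP head graph, instead
// of loading a second model with override_arch.
enum llama_graph_type {
    LLAMA_GRAPH_TYPE_DEFAULT = 0, // regular decoder graph
    LLAMA_GRAPH_TYPE_MTP     = 1, // graph for the NextN/MTP head only
};

// assumed extra field on the context parameters:
// struct llama_context_params {
//     ...
//     enum llama_graph_type graph_type; // which graph this context builds
// };
```

The server would then create a second context with `graph_type = LLAMA_GRAPH_TYPE_MTP` for drafting, sharing the model (and, after the refactor mentioned below, tensors) with the target context.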
|
@ngxson yes, the h2d copy was discussed with GG; he's working on a refactor which will allow us to share tensors between two llama contexts.
|
Great work, this should massively close the TG gap with vLLM, or maybe even surpass it together with tensor-parallel.
Currently, speculative decoding needs to restart from a checkpoint after some draft tokens are not accepted, which leads to some wasted work when running the target again. This PR adds the ability to roll back up to `draft_max` tokens by storing the GDN intermediates.
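Roughly, the bookkeeping looks like this (an illustrative sketch with invented names, not the PR's actual structures):

```cpp
// Sketch: snapshot the GDN/recurrent state once per drafted token so that, when
// only k of the drafted tokens are accepted, the state can be rolled back to the
// last accepted position instead of restarting from an older checkpoint.
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

struct gdn_state { std::vector<float> packed; }; // placeholder for the per-layer state

struct draft_window {
    int32_t base_pos = 0;             // position before drafting started
    std::vector<gdn_state> snapshots; // snapshots[i] = state after draft token i

    void record(gdn_state st) { snapshots.push_back(std::move(st)); }

    // state to restore when the target accepted k draft tokens (k == 0 keeps the base state)
    const gdn_state * rollback_to(size_t k) const {
        if (k == 0 || k > snapshots.size()) {
            return nullptr;
        }
        return &snapshots[k - 1];
    }
};
```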
|
In my opinion, Qwen 3.6 is the most important thing that has happened in open-source models in a long time; this is going to be so valuable. ngram could be set to match only very strong and long candidates, for large repetitive paraphrasing.
|
" idea is that MTP should automatically start and we shouldn't need to distribute the MTP gguf separately but also it has it's own context/kv-cache etc." -> Does this mean MTP needs additional resources (RAM/VRAM?) If so, there should always be an option to remain to disable it. Right now on my system (6 GB VRAM, 32 GB RAM), speculative decoding just makes things much slower even on very small draft models because of that exact reason, they need own context and kv-cache. Such low to midrange systems already operate on the edge in terms of memory. |
|
I'm getting garbage responses running this PR on the Vulkan backend with an R9700 using llama-server. I'm using the GGUF you linked above. Interestingly, draft acceptance is only 0.01282. Prompt: "Hello!"
|
@cmp-nct I'm not sure, but it could be possible. @Dampfinchen as of right now it is opt-in via @mbednarek360
|
Might it be possible/useful to run the draft model on a second GPU? Given that the MTP weights are relatively small, this might provide a useful speedup on systems with a dedicated high-VRAM "AI" GPU and a cheaper low-VRAM "normal" GPU used for display output, etc., and possibly prevent some degree of resource contention.
|
Thank you, we are eagerly awaiting this to become stable. Here are the automated test results for my machine:
Result:
|
@cturan Thanks for testing. I'm aware of the prefill issue and will work on a fix.
|
Might be a long shot, but any chance of supporting MTP with a reduced vocabulary? MTP layers are rather chonky and reducing token embeddings might help users with less VRAM by filtering out certain languages. Obviously the full model will still be able to produce those tokens if need be, so it won't be gimped.
|
Working on taking this for a spin with the Q4_K_M quant of Qwen3.6-35BA3B. I was gonna try to start from unsloth's quant since they already perform really well, but of course they don't have any MTP layers. @am17an Think it would work if I just "steal" the layers from your Q8 quant and merge them into the unsloth quant? (add blk.40 and bump some top-level config like block_count and kv_count)
Overview
This PR adds support for MTP (Multi-Token Prediction) heads. I tested this on Qwen3.6 27B and Qwen3.6 35BA3B, but in principle it should work for any MTP model. I've posted the detailed results below, but typically I see a steady-state acceptance of around 75% with 3 draft tokens, which is more than a 2x speed-up over baseline. The design decisions I took to get to this stage are as follows:
Performance
A simple bench for testing various prompts is here: https://gist.github.com/am17an/228edfb84ed082aa88e3865d6fa27090. Posting the results below:
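For reference, treating the ~75% acceptance as an independent per-token probability $p$ and drafting $n = 3$ tokens, the expected number of tokens committed per target pass (ignoring drafting overhead) is

$$\frac{1 - p^{\,n+1}}{1 - p} = \frac{1 - 0.75^{4}}{1 - 0.75} \approx 2.7,$$

which lines up with the >2x speed-up over baseline reported above.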
Performance on DGX Spark 🧵
No MTP (baseline)
```sh
./llama-server -m ../qwen3.6-q8_0.gguf -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}"
```

MTP --spec-draft-n-max 3
```sh
./llama-server -m ../qwen3.6-q8_0-mtp.gguf -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 3
```

MTP --spec-draft-n-max 2
```sh
./llama-server -m ../qwen3.6-q8_0-mtp.gguf -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}" --spec-type mtp --spec-draft-n-max 2
```

Draft model (Qwen3.5 0.8B) with spec-draft-n-max 16 with partial rollback
```sh
llama-server -m ../qwen3.6/Qwen3.6-27B-Q8_0.gguf -hfd unsloth/Qwen3.5-0.8B-GGUF:Q8_0 --spec-draft-n-max 16 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}"
```

Master with draft model with spec-draft-n-max 64 with no partial rollback
```sh
llama-server -m ../qwen3.6/Qwen3.6-27B-Q8_0.gguf -hfd unsloth/Qwen3.5-0.8B-GGUF:Q8_0 --spec-draft-n-max 64 -np 1 --chat-template-kwargs "{\"preserve_thinking\": true}"
```

How to use
I've uploaded the GGUF which I made by using the `convert_hf_to_gguf.py` changes in this PR. Here is another GGUF for the MoE (35BA3B) model.

Requirements