[Feature] Add MLA draft model support for Eagle3 training#55
Open
cicirori
commented
Mar 27, 2026
```yaml
train_data_path: ../examples/data/sample_conversations.jsonl
eval_data_path: ../examples/data/eval_conversations.jsonl
eval_interval: 100
# eval_data_path: ../examples/data/eval_conversations.jsonl
```
There is an eval hang issue, @yubofredwang.
Add DeepSeek MLA attention for Eagle3 draft model training, supporting
both SDPA and flex_attention backends. This enables using MLA-based draft
models (DeepSeek-V2/V3 style) with any target model.
New files:
- torchspec/models/draft/deepseek_eagle.py: MLA attention, decoder layer,
and DeepSeekForCausalLMEagle3 draft model
- configs/draft_models/deepseek_v3_eagle3.json: DeepSeek-V3 draft config
- configs/draft_models/qwen3_8b_eagle3_mla.json: Qwen3-8B + MLA draft config
- configs/sglang_qwen3_8b_mla_draft{,_flex}.yaml: e2e training configs
- tests/test_deepseek_eagle.py: 12 unit tests (SDPA + flex, CPU + CUDA)
Also includes e2e training loop fixes for single-node runs.
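Using an MLA draft model with any target model implies a config-to-model dispatch, and the PR registers a DeepseekV3Config → DeepSeekForCausalLMEagle3 mapping in torchspec/models/draft/auto.py. That pattern can be sketched as a small registry; this is a hypothetical sketch (all names besides the two registered classes are invented), not the actual auto.py code:

```python
# Sketch of a config-class -> draft-model registry, mirroring the
# DeepseekV3Config -> DeepSeekForCausalLMEagle3 dispatch this PR adds in
# torchspec/models/draft/auto.py. Everything besides those two names is
# hypothetical; the real auto.py may be structured differently.
DRAFT_MODEL_REGISTRY: dict = {}

def register_draft_model(config_cls):
    def wrap(model_cls):
        DRAFT_MODEL_REGISTRY[config_cls] = model_cls
        return model_cls
    return wrap

class DeepseekV3Config:
    """Stand-in for the Hugging Face DeepSeek-V3 config class."""

@register_draft_model(DeepseekV3Config)
class DeepSeekForCausalLMEagle3:
    def __init__(self, config):
        self.config = config

def auto_draft_model(config):
    """Instantiate the draft model registered for this config type."""
    return DRAFT_MODEL_REGISTRY[type(config)](config)

draft = auto_draft_model(DeepseekV3Config())
assert isinstance(draft, DeepSeekForCausalLMEagle3)
```

A type-keyed registry keeps auto.py free of per-model if/else chains: adding a new draft architecture is one decorator on the model class.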
Summary

convert_to_hf.py: lm_head pruning bug, draft_vocab_size validation, train-time pruning detection

Changes

- torchspec/models/draft/deepseek_eagle.py — DeepSeekMLAAttention, DeepSeekMLAFlexAttention, DeepSeekDecoderLayer, DeepSeekForCausalLMEagle3
- torchspec/models/draft/auto.py — register DeepseekV3Config → DeepSeekForCausalLMEagle3 dispatch
- configs/draft_models/qwen3_8b_eagle3_mla.json — Qwen3-8B target + MLA draft config (full vocab, no train-time pruning)
- configs/sglang_qwen3_8b_mla_draft.yaml — E2E training config (flex_attention)
- tests/test_deepseek_eagle.py — 12 unit tests (shapes, gradients, config dispatch, softmax scale, TTT loop, SDPA vs flex consistency)
- tools/convert_to_hf.py — lm_head pruning fix (arange + d2t → direct d2t indexing); reject draft_vocab_size != vocab_size without --prune-vocab (the model needs t2d/d2t)

Test plan

- ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml — 500 steps on B200
- convert_to_hf.py without --prune-vocab → full-vocab model, from_pretrained OK
- convert_to_hf.py with --prune-vocab --draft-vocab-size 32000 → pruned model with t2d/d2t, from_pretrained OK

E2E training result (500 steps, Qwen3-8B + MLA draft, flex_attention, B200)
HF conversion verification
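The lm_head pruning fix in convert_to_hf.py (direct d2t indexing instead of arange + d2t) can be illustrated with a toy sketch. This assumes d2t maps each draft-token id to an absolute target-vocab id, so a single gather selects the kept rows; variable names are hypothetical and the actual script may differ:

```python
import torch

def prune_lm_head(weight: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    """Gather the target-model lm_head rows used by the draft vocabulary.

    Hypothetical sketch: assumes d2t holds absolute target-vocab ids per
    draft token, so a direct gather replaces the older
    `arange(draft_vocab_size) + d2t` offset scheme mentioned in the PR.
    """
    return weight[d2t]

# Toy data: target vocab 6, hidden dim 4; draft vocab keeps target ids 0, 2, 5.
weight = torch.arange(24, dtype=torch.float32).reshape(6, 4)
d2t = torch.tensor([0, 2, 5])
pruned = prune_lm_head(weight, d2t)
assert pruned.shape == (3, 4)
assert torch.equal(pruned[1], weight[2])  # draft id 1 -> target row 2
```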
FSDP2 compatibility note
MLA's bottleneck structure (down_proj → RMSNorm → up_proj) is compatible with TorchSpec's per-Linear FSDP2 sharding. FSDP2 shards parameters only — weights are all-gathered before forward, so activations (including the compressed latent and the `torch.split` on the `kv_a_proj` output) are always full-size during computation. RMSNorm layers are not `nn.Linear` and remain replicated. This is the same sharding pattern used by the existing Llama attention, verified end-to-end on B200 (500 steps; loss converges to 1.274).