
[Feature] Add MLA draft model support for Eagle3 training#55

Open
cicirori wants to merge 1 commit into main from feature/mla-draft-model

Conversation


@cicirori (Collaborator) commented Mar 27, 2026

Summary

  • Add DeepSeek MLA (Multi-head Latent Attention) draft model for Eagle3 online training
  • Support both SDPA and flex_attention backends
  • Enable using MLA-based draft models (DeepSeek-V2/V3 style) with any target model
  • Reference config based on nvidia/Kimi-K2.5-Thinking-Eagle3
  • Fix convert_to_hf.py: lm_head pruning bug, draft_vocab_size validation, train-time pruning detection

Changes

  • torchspec/models/draft/deepseek_eagle.py — DeepSeekMLAAttention, DeepSeekMLAFlexAttention, DeepSeekDecoderLayer, DeepSeekForCausalLMEagle3
  • torchspec/models/draft/auto.py — Register DeepseekV3Config → DeepSeekForCausalLMEagle3 dispatch
  • configs/draft_models/qwen3_8b_eagle3_mla.json — Qwen3-8B target + MLA draft config (full vocab, no train-time pruning)
  • configs/sglang_qwen3_8b_mla_draft.yaml — E2E training config (flex_attention)
  • tests/test_deepseek_eagle.py — 12 unit tests (shapes, gradients, config dispatch, softmax scale, TTT loop, SDPA vs flex consistency)
  • tools/convert_to_hf.py:
    • Fix lm_head trimming bug (arange + d2t → direct d2t indexing)
    • Error when draft_vocab_size != vocab_size without --prune-vocab (model needs t2d/d2t)
    • Error when lm_head already pruned (train-time pruning incompatible with post-training re-prune)
  • E2E training loop fixes for single-node runs
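The lm_head trimming fix above can be sketched as follows. This is a minimal illustration, not the PR's actual code: it assumes (as the "arange + d2t → direct d2t indexing" description implies) that d2t holds absolute target-vocab row indices for each draft-vocab id, and the shapes are illustrative.

```python
import torch

def prune_lm_head(lm_head_weight: torch.Tensor, d2t: torch.Tensor) -> torch.Tensor:
    """Select the target-vocab rows of lm_head that the draft vocab maps to.

    The fix description suggests the old code added an arange offset on top of
    d2t, which double-offsets when d2t already stores absolute row indices;
    direct indexing selects the correct rows.
    """
    return lm_head_weight[d2t]

# Illustrative shapes (not the PR's 151936-vocab config).
full_vocab, draft_vocab, hidden = 1000, 100, 16
weight = torch.randn(full_vocab, hidden)
d2t = torch.randperm(full_vocab)[:draft_vocab]  # hypothetical draft->target mapping

pruned = prune_lm_head(weight, d2t)
# pruned has shape (draft_vocab, hidden); row i equals weight[d2t[i]].
```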

Test plan

  • Unit tests: 12/12 passed (CPU + CUDA, SDPA + flex_attention)
  • E2E: ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml — 500 steps on B200
  • convert_to_hf.py without --prune-vocab → full vocab model, from_pretrained OK
  • convert_to_hf.py with --prune-vocab --draft-vocab-size 32000 → pruned model with t2d/d2t, from_pretrained OK
  • pre-commit (ruff + ruff-format + isort) all pass

E2E training result (500 steps, Qwen3-8B + MLA draft, flex_attention, B200)

$ ./examples/qwen3-8b-single-node/run.sh configs/sglang_qwen3_8b_mla_draft.yaml

Training: 100% | 500/500 [01:45<00:00, 4.75step/s, loss=1.274, acc=0.865, acc_len=4.15, thru=6.5, I=145.7, T=12.6, wait=0.0s, pool=0, epoch=1/1]

HF conversion verification

# Without --prune-vocab (full vocab)
$ python tools/convert_to_hf.py --input-dir outputs/.../iter_0000501 --config configs/draft_models/qwen3_8b_eagle3_mla.json -f
vocab=151936/151936, lm_head=[151936, 4096], has_vocab_pruning=False, params=1,493,714,944

# With --prune-vocab (post-training pruning)
$ python tools/convert_to_hf.py ... --prune-vocab --draft-vocab-size 32000 --dataset-path examples/data/sample_conversations.jsonl --tokenizer Qwen/Qwen3-8B --chat-template qwen --prompt-key conversations
vocab=32000/151936, lm_head=[32000, 4096], has_vocab_pruning=True, params=1,002,457,088

# Both load correctly via from_pretrained
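As a quick arithmetic sanity check on the logs above (not part of the PR): the parameter delta between the full-vocab and pruned checkpoints matches exactly the dropped lm_head rows times the hidden size, confirming that --prune-vocab removes only lm_head rows.

```python
# Counts and shapes taken from the conversion logs above.
full_params, pruned_params = 1_493_714_944, 1_002_457_088
full_vocab, draft_vocab, hidden = 151_936, 32_000, 4_096

# Delta == (151936 - 32000) * 4096 == 491,257,856: only lm_head rows differ.
assert full_params - pruned_params == (full_vocab - draft_vocab) * hidden
```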

FSDP2 compatibility note

MLA's bottleneck structure (down_proj → RMSNorm → up_proj) is compatible with TorchSpec's per-Linear FSDP2 sharding. FSDP2 shards parameters only — weights are all-gathered before forward, so activations (including the compressed latent and torch.split on kv_a_proj output) are always full-size during computation. RMSNorm layers are not nn.Linear and remain replicated. This is the same sharding pattern used by the existing Llama attention, verified end-to-end on B200 (500 steps, loss converges to 1.274).
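The bottleneck structure described above can be sketched as a minimal module. This is illustrative, not the PR's implementation: module names follow the DeepSeek-V2/V3 convention (kv_a_proj, kv_a_layernorm, kv_b_proj) and the dimensions are placeholders. Under FSDP2, only the two nn.Linear weights would be sharded (e.g. via a per-Linear fully_shard call); RMSNorm stays replicated, and all activations, including the torch.split outputs, are full-size.

```python
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    """Minimal RMSNorm; not nn.Linear, so replicated under per-Linear FSDP2."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        return self.weight * x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

class MLAKVBottleneck(nn.Module):
    """Sketch of MLA's low-rank KV path: down_proj -> RMSNorm -> up_proj."""
    def __init__(self, hidden=4096, kv_lora_rank=512, qk_rope_head_dim=64,
                 num_heads=32, qk_nope_head_dim=128, v_head_dim=128):
        super().__init__()
        self.kv_lora_rank = kv_lora_rank
        self.qk_rope_head_dim = qk_rope_head_dim
        # Down-projection: hidden -> compressed latent + shared RoPE key part.
        self.kv_a_proj = nn.Linear(hidden, kv_lora_rank + qk_rope_head_dim, bias=False)
        self.kv_a_layernorm = RMSNorm(kv_lora_rank)
        # Up-projection: latent -> per-head K (nope part) and V.
        self.kv_b_proj = nn.Linear(
            kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim), bias=False)

    def forward(self, x):
        # torch.split runs on the full-size (all-gathered-weight) output;
        # FSDP2 shards only the Linear parameters, never the activations.
        latent, k_rope = torch.split(
            self.kv_a_proj(x), [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
        kv = self.kv_b_proj(self.kv_a_layernorm(latent))
        return kv, k_rope
```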

@cicirori cicirori force-pushed the feature/mla-draft-model branch 3 times, most recently from 6a1182b to 2319f08 Compare March 27, 2026 19:20
train_data_path: ../examples/data/sample_conversations.jsonl
eval_data_path: ../examples/data/eval_conversations.jsonl
eval_interval: 100
# eval_data_path: ../examples/data/eval_conversations.jsonl

There is an eval hang issue, @yubofredwang.

@cicirori cicirori force-pushed the feature/mla-draft-model branch 4 times, most recently from 88a6b53 to fee3ed3 Compare March 27, 2026 19:33
@cicirori cicirori requested a review from yubofredwang March 27, 2026 19:35
@cicirori cicirori force-pushed the feature/mla-draft-model branch 5 times, most recently from 1840771 to df94bfc Compare March 27, 2026 19:57
Add DeepSeek MLA attention for Eagle3 draft model training, supporting
both SDPA and flex_attention backends. This enables using MLA-based draft
models (DeepSeek-V2/V3 style) with any target model.

New files:
- torchspec/models/draft/deepseek_eagle.py: MLA attention, decoder layer,
  and DeepSeekForCausalLMEagle3 draft model
- configs/draft_models/deepseek_v3_eagle3.json: DeepSeek-V3 draft config
- configs/draft_models/qwen3_8b_eagle3_mla.json: Qwen3-8B + MLA draft config
- configs/sglang_qwen3_8b_mla_draft{,_flex}.yaml: e2e training configs
- tests/test_deepseek_eagle.py: 12 unit tests (SDPA + flex, CPU + CUDA)

Also includes e2e training loop fixes for single-node runs.
@cicirori cicirori force-pushed the feature/mla-draft-model branch from df94bfc to fdbdd38 Compare March 27, 2026 19:59
