Feat/megatron weight sync dev#9
Open
jsw-zorro wants to merge 4 commits into
Add export_hf_named_params: a streaming generator that reconstructs the global model from Megatron's TP/PP/EP/ETP/VPP layout and yields HF-named, HF-layout CPU tensors one at a time (OOM-safe for large / MoE models). The gather + mcore->HF conversion is delegated to mbridge's export_weights (the same bridge the engine already uses to load/save); this module adds the consumer concerns: CPU move, byte-bounded bucketing, and a metadata-only path for transfer-buffer sizing.

This is the foundation for correct sparse weight sync under full Megatron parallelism. The design (the "delta is computed in HF byte space" invariant) is documented in docs/en/architecture/megatron-weight-sync.md.

The Megatron backend needs two extra compiled deps beyond the base install (megatron-core / mbridge are already there): Transformer Engine (fused LayerNorm + sequence parallelism) and apex (optional fused LayerNorm/Adam). These are kept out of the default image: a separate docker/Dockerfile.sglang.megatron layers them on top of Dockerfile.sglang, and installation.md gains an optional "Step 5: Install the Megatron training backend" under Option A. The FSDP backend and inference are unaffected.

Validated (exact bf16 match vs the HF reference checkpoint):
- Qwen3-0.6B TP=2: 310 tensors, 0 mismatch
- Qwen3-0.6B PP=2: 311 tensors, 0 mismatch
- Qwen3-0.6B TP=2 PP=2: 311 tensors, 0 mismatch
- Qwen3-30B-A3B TP=2 EP=2 PP=2: 18867 tensors, 0 mismatch
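For reviewers, a minimal sketch of what "byte-bounded bucketing" and "metadata-only sizing" could look like. This is not the code in weight_export.py: TensorMeta, DTYPE_BYTES, total_buffer_bytes and bucket_by_bytes are hypothetical names, and it works on shape/dtype records only so it runs without torch or Megatron.

```python
from dataclasses import dataclass
from math import prod
from typing import Iterable, Iterator

# Bytes per element for a few dtypes (illustrative subset, assumption).
DTYPE_BYTES = {"bfloat16": 2, "float16": 2, "float32": 4}


@dataclass(frozen=True)
class TensorMeta:
    """Hypothetical metadata record: HF name, shape, dtype, no tensor data."""
    name: str
    shape: tuple[int, ...]
    dtype: str

    @property
    def nbytes(self) -> int:
        return prod(self.shape) * DTYPE_BYTES[self.dtype]


def total_buffer_bytes(metas: Iterable[TensorMeta]) -> int:
    """Size a transfer buffer from metadata alone."""
    return sum(m.nbytes for m in metas)


def bucket_by_bytes(metas: Iterable[TensorMeta], max_bytes: int) -> Iterator[list[TensorMeta]]:
    """Yield groups whose combined size stays within max_bytes.

    An entry larger than max_bytes is yielded in a bucket of its own.
    """
    bucket: list[TensorMeta] = []
    used = 0
    for m in metas:
        if bucket and used + m.nbytes > max_bytes:
            yield bucket
            bucket, used = [], 0
        bucket.append(m)
        used += m.nbytes
    if bucket:
        yield bucket


metas = [
    TensorMeta("a", (4, 4), "bfloat16"),  # 32 bytes
    TensorMeta("b", (8,), "float32"),     # 32 bytes
    TensorMeta("c", (100,), "bfloat16"),  # 200 bytes, exceeds the budget alone
    TensorMeta("d", (2,), "float16"),     # 4 bytes
]
buckets = [[m.name for m in b] for b in bucket_by_bytes(metas, max_bytes=64)]
```

Because bucket_by_bytes is a generator over an iterable, it holds at most one bucket at a time, which is the property that keeps a streaming export bounded in memory.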
Replace the TP-only shard-direct weight transfer with the HF-export path:
- MegatronEngine.export_hf_named_params() / get_hf_weight_metadata() stream gathered HF tensors via mbridge (handles TP/PP/EP/ETP/VPP). The previous PP>1 / EP>1 NotImplementedError guards are removed.
- WeightManager gains "megatron_hf_meta" mode: the transfer buffer is sized for the full HF model and offload() streams HF tensors into the inactive half on the writer rank, while the gather collectives run on all ranks in lockstep. The sender receives megatron_metadata=None and runs the plain full/delta path used by FSDP. Because the buffer now holds HF-layout bytes, the sparse delta is computed in HF space and is correct under any parallelism. This fixes the latent corruption where the delta was computed in mcore layout but applied by the receiver in HF layout.
- ppo_trainer wires the generator + HF metadata through.

The legacy CPU shard-reassembly in the sender agent is now unused for Megatron (kept only for the deprecated megatron_metadata path).

Validated (buffer roundtrip == HF reference, bit-exact):
- Qwen3-0.6B TP=2: 310 tensors, 0 mismatch, 1.19 GB
- Qwen3-0.6B TP=2 PP=2: 311 tensors, 0 mismatch, 1.50 GB
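To illustrate why the layout in which the delta is computed matters, here is a toy sketch under stated assumptions. The "sharded" layout below is a simple column split concatenated rank by rank; it only stands in for a non-HF byte order and is not Megatron's actual mcore layout. All function names are hypothetical.

```python
def sparse_delta(old: bytes, new: bytes) -> list[tuple[int, int]]:
    """Return (offset, new_value) for every byte that changed."""
    assert len(old) == len(new)
    return [(i, b) for i, (a, b) in enumerate(zip(old, new)) if a != b]


def apply_delta(buf: bytearray, delta: list[tuple[int, int]]) -> None:
    """Patch the receiver's buffer in place."""
    for offset, value in delta:
        buf[offset] = value


def to_shard_layout(hf: bytes, rows: int, cols: int, tp: int) -> bytes:
    """Toy layout: split a row-major [rows, cols] matrix by columns
    across tp ranks and concatenate the per-rank shards."""
    width = cols // tp
    out = bytearray()
    for rank in range(tp):
        for r in range(rows):
            start = r * cols + rank * width
            out += hf[start:start + width]
    return bytes(out)


rows, cols, tp = 2, 4, 2
old_hf = bytes(range(8))          # [[0, 1, 2, 3], [4, 5, 6, 7]]
new_hf = bytearray(old_hf)
new_hf[2] = 99                    # one weight byte changes
new_hf = bytes(new_hf)

# Delta computed in HF byte space: offsets match the receiver's buffer.
good = bytearray(old_hf)
apply_delta(good, sparse_delta(old_hf, new_hf))

# Delta computed in the toy sharded space, then applied to an HF buffer:
# the offset now points at a different element.
bad = bytearray(old_hf)
apply_delta(bad, sparse_delta(to_shard_layout(old_hf, rows, cols, tp),
                              to_shard_layout(new_hf, rows, cols, tp)))
```

Here good equals new_hf, while bad has the changed byte written at offset 4 instead of offset 2, so the receiver ends up with a wrong weight and a stale one.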
Add examples/math/qwen3-8b-megatron-delta: the FSDP qwen3-8b-m2po-delta recipe with the trainer engine switched to the Megatron backend (backend: megatron, tensor_parallel_size: 4). Identical data, algorithm, and weight-transfer path, so it doubles as a clean FSDP-vs-Megatron A/B.

End-to-end validation (single 8-GPU node, 4 RaaS + 4 trainer TP=4, delta TCP weight sync, DeepScaleR/M2PO):
- Qwen3-8B (this recipe): 25 steps, 0 errors; weight_transfer/delta_sparsity ~0.92 (delta computed in HF space); task_reward/avg rose 0.535 (first half) -> 0.585 (last half), recent steps 0.61-0.64. Per-step weight offload 0.59s.
- Qwen3-30B-A3B MoE (TP2/PP3/EP2 trainer on 6 GPUs + SGLang TP2 on 2 GPUs): 21 steps, 0 errors; full MoE export (18867 tensors, 61 GB) gathered across TP/PP/EP each step; task_reward/avg ~0.64 -> 0.66 (recent steps 0.70-0.77).
The Megatron HF-export offload materialized each gathered tensor in
pageable host memory via .to("cpu") before copying it into the pinned
shared-memory transfer buffer — a ~1 GB/s bounce that cost ~13s/step for
an 8B model on the RL critical path.
Now the engine yields the gathered tensors on GPU (export_hf_named_params
to_cpu=False) and WeightManager copies each tensor's uint8 view directly
into the pinned buffer slice (non_blocking=True), fenced by a single
cuda.synchronize() before the cross-rank barrier. The pinned buffer is
already cudaHostRegister'd, so this hits the PCIe DMA engine.
Copying through uint8 views on both sides keeps the copy alignment-free
(robust to mixed-dtype models) and byte-identical for contiguous sources.
Measured (Qwen3-8B, TP=4, 16.38 GB):
pageable (old): 12.6s (1.3 GB/s)
direct DMA: 0.56s (29.3 GB/s) ~23x
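As a quick consistency check of the figures above (plain arithmetic on the reported size and timings, nothing measured here):

```python
size_gb = 16.38                     # reported payload for Qwen3-8B, TP=4
pageable_s, dma_s = 12.6, 0.56      # reported offload times

pageable_gbps = size_gb / pageable_s   # about 1.3 GB/s
dma_gbps = size_gb / dma_s             # about 29.3 GB/s
speedup = pageable_s / dma_s           # about 22.5, reported as ~23x
```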
Byte-equivalence verified (new buffer == old pageable path == HF
reference, bit-exact) across TP=2, PP=2, TP=2/PP=2, and MoE TP=2/EP=2/PP=2
(Qwen3-30B-A3B, 61 GB). Adds tests/test_direct_dma_offload.py (equivalence)
and tests/bench_offload_dma.py (throughput).
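The byte-view idea can be sketched without a GPU. The example below is a CPU-only stand-in, assuming stdlib array objects in place of tensors and a bytearray in place of the pinned buffer; memoryview.cast("B") plays the role of the uint8 view. It does not model pinned memory, non_blocking copies, or the synchronize/barrier fence, and pack_as_bytes is a hypothetical name.

```python
from array import array


def pack_as_bytes(tensors: dict[str, array], buf: bytearray) -> dict[str, tuple[int, int]]:
    """Copy each array's raw bytes into buf back to back.

    Returns name -> (offset, nbytes). Offsets are not aligned to the
    element size, which is fine because both sides are plain bytes.
    """
    layout: dict[str, tuple[int, int]] = {}
    dst = memoryview(buf)
    offset = 0
    for name, arr in tensors.items():
        src = memoryview(arr).cast("B")  # reinterpret elements as unsigned bytes
        n = src.nbytes
        dst[offset:offset + n] = src     # byte-for-byte copy into the slice
        layout[name] = (offset, n)
        offset += n
    return layout


# Mixed element types, so later offsets are not multiples of the element size.
tensors = {
    "w_f32": array("f", [1.0, 2.0, 3.0]),
    "w_i16": array("h", [7]),
    "w_f64": array("d", [0.5]),
}
buf = bytearray(sum(a.itemsize * len(a) for a in tensors.values()))
layout = pack_as_bytes(tensors, buf)

# Read one entry back from its byte range to confirm the round trip.
off, n = layout["w_f64"]
restored = array("d")
restored.frombytes(bytes(buf[off:off + n]))
```

Treating both source and destination as bytes means the copy never depends on element alignment, which is why it stays valid when dtypes of different widths are packed next to each other.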
Description
Summary
Adds end-to-end support for the Megatron-LM training backend by implementing
its weight-synchronization path to RaaS, correct under full parallelism
(TP / PP / EP / ETP / VPP, including MoE). The MegatronEngine previously raised
NotImplementedError for PP>1 / EP>1 weight transfer; this PR makes Megatron
actually trainable end-to-end and removes those restrictions.
What's in here
- models/mcore/weight_export.py: export_hf_named_params reconstructs the global model from Megatron's sharded layout and yields HF-named, HF-layout tensors one at a time (OOM-safe for large / MoE models), delegating the gather + conversion to mbridge (the same bridge the engine already uses to load/save).
- weight_manager.py, megatron_engine.py, ppo_trainer.py: the transfer buffer holds HF-layout bytes, so the existing sparse-delta path is correct under any parallelism, fixing a latent corruption where the delta was computed in mcore layout but applied by the receiver in HF layout.
- weight_manager.py: copies gathered tensors straight from GPU into the pinned transfer buffer instead of via a pageable .to("cpu") bounce. Measured 12.6s → 0.56s (~23×) for an 8B model.
- docs/en/architecture/megatron-weight-sync.md; an optional "Install the Megatron training backend" step in installation.md (Option A); and a separate docker/Dockerfile.sglang.megatron that layers Transformer Engine + apex on top of the SGLang image (the default image and FSDP/inference are unaffected).
- examples/math/qwen3-8b-megatron-delta: the FSDP qwen3-8b-m2po-delta recipe with the trainer engine switched to Megatron (clean FSDP-vs-Megatron A/B).
- test_hf_export_equiv.py, test_megatron_hf_offload.py, test_direct_dma_offload.py, and an offload throughput bench (bench_offload_dma.py).
Validation
Weight export — exact bf16 match vs the HF reference checkpoint:
Offload throughput (Qwen3-8B, TP=4, 16.38 GB): pageable 12.6s (1.3 GB/s) →
direct DMA 0.56s (29.3 GB/s), ~23×.
End-to-end RL (single 8-GPU node, delta TCP weight sync, DeepScaleR/M2PO):
- Qwen3-8B: weight_transfer/delta_sparsity ~0.92; task_reward/avg 0.535 → 0.585 (recent steps 0.61–0.64); per-step offload 0.59s.
- Qwen3-30B-A3B MoE: 21 steps, 0 errors; full MoE export (18867 tensors, 61 GB) gathered across TP/PP/EP each step; task_reward/avg ~0.64 → 0.66 (recent steps 0.70–0.77).
Notes
- Based on dev, so it includes the SGLang 0.5.12 changes already there.
- The tests need multi-GPU hardware and were run manually (results above); CI runs only lint/format/docs.