Feat/megatron weight sync dev#9

Open
jsw-zorro wants to merge 4 commits into Infini-AI-Lab:dev from jsw-zorro:feat/megatron-weight-sync-dev
@jsw-zorro
Description

Summary

Adds end-to-end support for the Megatron-LM training backend by implementing
its weight-synchronization path to RaaS, correct under full parallelism
(TP / PP / EP / ETP / VPP, including MoE). The MegatronEngine previously
raised NotImplementedError for PP>1 / EP>1 weight transfer; this PR makes
Megatron actually trainable end-to-end and removes those restrictions.

What's in here

  • Streaming Megatron→HF weight export (models/mcore/weight_export.py):
    export_hf_named_params reconstructs the global model from Megatron's sharded
    layout and yields HF-named, HF-layout tensors one at a time (OOM-safe for
    large / MoE models), delegating the gather + conversion to mbridge
    (the same bridge the engine already uses to load/save).
  • HF-space weight buffer + delta (weight_manager.py, megatron_engine.py,
    ppo_trainer.py): the transfer buffer holds HF-layout bytes, so the existing
    sparse-delta path is correct under any parallelism — fixing a latent
    corruption where the delta was computed in mcore layout but applied by the
    receiver in HF layout.
  • Direct-DMA offload (weight_manager.py): copies gathered tensors straight
    from GPU into the pinned transfer buffer instead of via a pageable
    .to("cpu") bounce. Measured 12.6s → 0.56s (~23×) for an 8B model.
  • Docs + Docker: design note docs/en/architecture/megatron-weight-sync.md;
    an optional "Install the Megatron training backend" step in installation.md
    (Option A); and a separate docker/Dockerfile.sglang.megatron that layers
    Transformer Engine + apex on top of the SGLang image (the default image and
    FSDP/inference are unaffected).
  • Recipe: examples/math/qwen3-8b-megatron-delta — the FSDP
    qwen3-8b-m2po-delta recipe with the trainer engine switched to Megatron
    (clean FSDP-vs-Megatron A/B).
  • Tests: byte-equivalence (test_hf_export_equiv.py,
    test_megatron_hf_offload.py, test_direct_dma_offload.py) and an offload
    throughput bench (bench_offload_dma.py).
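
The layout mismatch behind the delta fix can be reproduced with a toy model. The sketch below is not the PR's code: `sparse_delta`, `apply_delta`, and the interleaving `to_shard_layout` are hypothetical stand-ins, and plain `bytes` replace real tensors. It only illustrates why a delta must be computed in the same byte layout the receiver applies it in.

```python
def sparse_delta(old: bytes, new: bytes) -> dict[int, int]:
    """Return {byte_offset: new_value} for every byte that changed."""
    assert len(old) == len(new)
    return {i: b for i, (a, b) in enumerate(zip(old, new)) if a != b}


def apply_delta(base: bytes, delta: dict[int, int]) -> bytes:
    """Patch the changed bytes into a copy of `base`."""
    out = bytearray(base)
    for offset, value in delta.items():
        out[offset] = value
    return bytes(out)


def to_shard_layout(hf: bytes) -> bytes:
    """Toy stand-in for a sharded layout: even-index bytes first, then odd."""
    return hf[0::2] + hf[1::2]


hf_old = bytes([10, 11, 12, 13, 14, 15])
hf_new = bytes([10, 99, 12, 13, 14, 15])  # one byte changed, at HF offset 1

# Delta computed and applied in the same (HF) layout: the receiver ends up
# with exactly the new weights.
good = apply_delta(hf_old, sparse_delta(hf_old, hf_new))

# Delta computed in the shard layout but applied to HF-layout bytes: the
# changed byte lands at the wrong offset and the receiver is corrupted.
bad_delta = sparse_delta(to_shard_layout(hf_old), to_shard_layout(hf_new))
bad = apply_delta(hf_old, bad_delta)
```

Here `good` equals `hf_new`, while `bad` differs from it because offset 3 in the shard layout corresponds to offset 1 in the HF layout.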

Validation

Weight export — exact bf16 match vs the HF reference checkpoint:

| Config | Result |
| --- | --- |
| Qwen3-0.6B TP=2 | 310 tensors, 0 mismatch |
| Qwen3-0.6B PP=2 | 311 tensors, 0 mismatch |
| Qwen3-0.6B TP=2 PP=2 | 311 tensors, 0 mismatch |
| Qwen3-30B-A3B TP=2 EP=2 PP=2 | 18867 tensors, 0 mismatch |
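
A check of this kind can be sketched in a few lines. This is an illustrative outline, not the contents of test_hf_export_equiv.py: `compare_export` and the tensor names are hypothetical, and raw `bytes` stand in for bf16 tensor storage. It consumes a stream of `(name, raw_bytes)` pairs and reports tensors that differ from, or are absent in, the reference.

```python
from typing import Iterable, Mapping


def compare_export(
    exported: Iterable[tuple[str, bytes]],
    reference: Mapping[str, bytes],
) -> tuple[int, list[str], list[str]]:
    """Compare streamed (name, raw_bytes) pairs against a reference mapping.

    Returns (tensors_seen, mismatched_names, names_missing_from_export).
    """
    seen = set()
    mismatched = []
    for name, raw in exported:
        seen.add(name)
        # An unknown name or any differing byte counts as a mismatch.
        if reference.get(name) != raw:
            mismatched.append(name)
    missing = sorted(set(reference) - seen)
    return len(seen), mismatched, missing


# Hypothetical two-tensor "checkpoint" used only to exercise the function.
reference = {"embed.weight": b"\x01\x02\x03\x04", "norm.weight": b"\x05\x06"}
exported = iter([("embed.weight", b"\x01\x02\x03\x04"), ("norm.weight", b"\x05\x06")])

count, mismatched, missing = compare_export(exported, reference)
print(f"{count} tensors, {len(mismatched)} mismatch")
```

Because the exported side is consumed as an iterator, the comparison never needs more than one exported tensor in memory at a time.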

Offload throughput (Qwen3-8B, TP=4, 16.38 GB): pageable 12.6s (1.3 GB/s) →
direct DMA 0.56s (29.3 GB/s), ~23×.
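
The reported rates and speedup follow directly from the payload size and the two timings. The snippet below simply redoes that arithmetic with the numbers quoted above; the helper name is illustrative.

```python
def throughput_gb_per_s(size_gb: float, seconds: float) -> float:
    """Average transfer rate for a payload of `size_gb` moved in `seconds`."""
    return size_gb / seconds


size_gb = 16.38                  # Qwen3-8B export size reported above
pageable_s, dma_s = 12.6, 0.56   # measured offload times reported above

pageable_rate = throughput_gb_per_s(size_gb, pageable_s)  # about 1.3 GB/s
dma_rate = throughput_gb_per_s(size_gb, dma_s)            # about 29.25 GB/s
speedup = pageable_s / dma_s                              # about 22.5x

print(f"{pageable_rate:.2f} GB/s -> {dma_rate:.2f} GB/s, {speedup:.1f}x")
```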

End-to-end RL (single 8-GPU node, delta TCP weight sync, DeepScaleR/M2PO):

  • Qwen3-8B (4 RaaS + 4 trainer TP=4): 25 steps, 0 errors;
    weight_transfer/delta_sparsity ~0.92; task_reward/avg 0.535 → 0.585
    (recent steps 0.61–0.64); per-step offload 0.59s.
  • Qwen3-30B-A3B MoE (TP2/PP3/EP2 trainer on 6 GPUs + SGLang TP2 on 2 GPUs):
    21 steps, 0 errors; full MoE export (18867 tensors, 61 GB) gathered across
    TP/PP/EP each step; task_reward/avg ~0.64 → 0.66 (recent steps 0.70–0.77).

Notes

  • Rebased on dev, so it includes the SGLang 0.5.12 changes already there.
  • GPU-dependent suites (export equivalence, offload bench, e2e) require
    multi-GPU hardware and were run manually (results above); CI runs only
    lint/format/docs.

jsw-zorro added 4 commits May 30, 2026 04:21
Add export_hf_named_params: a streaming generator that reconstructs the
global model from Megatron's TP/PP/EP/ETP/VPP layout and yields HF-named,
HF-layout CPU tensors one at a time (OOM-safe for large / MoE models).
The gather + mcore->HF conversion is delegated to mbridge's export_weights
(the same bridge the engine already uses to load/save); this module adds
the consumer concerns: CPU move, byte-bounded bucketing, and a
metadata-only path for transfer-buffer sizing.
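
The byte-bounded bucketing mentioned above can be sketched as a generator over a stream of named tensors. This is a minimal illustration under stated assumptions, not the module's actual implementation: `bucket_by_bytes`, `max_bucket_bytes`, and the use of raw `bytes` in place of tensors are all hypothetical.

```python
from typing import Iterable, Iterator


def bucket_by_bytes(
    tensors: Iterable[tuple[str, bytes]],
    max_bucket_bytes: int,
) -> Iterator[list[tuple[str, bytes]]]:
    """Group a stream of (name, raw_bytes) pairs into byte-bounded buckets.

    A tensor larger than the bound is emitted alone in its own bucket.
    """
    bucket: list[tuple[str, bytes]] = []
    bucket_bytes = 0
    for name, raw in tensors:
        # Flush first if this tensor would overflow a non-empty bucket.
        if bucket and bucket_bytes + len(raw) > max_bucket_bytes:
            yield bucket
            bucket, bucket_bytes = [], 0
        bucket.append((name, raw))
        bucket_bytes += len(raw)
    if bucket:
        yield bucket


def fake_export() -> Iterator[tuple[str, bytes]]:
    """Stand-in for a streaming export: yields one tensor at a time."""
    for name, size in [("a", 4), ("b", 4), ("c", 6), ("d", 12), ("e", 2)]:
        yield name, bytes(size)


buckets = [[name for name, _ in b] for b in bucket_by_bytes(fake_export(), 10)]
print(buckets)  # [['a', 'b'], ['c'], ['d'], ['e']]
```

Only the current bucket is held in memory, which is what keeps peak usage bounded regardless of model size.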

This is the foundation for correct sparse weight sync under full Megatron
parallelism. The design (the "delta is computed in HF byte space"
invariant) is documented in docs/en/architecture/megatron-weight-sync.md.

The Megatron backend needs two extra compiled deps beyond the base install
(megatron-core / mbridge are already there): Transformer Engine (fused
LayerNorm + sequence parallelism) and apex (optional fused LayerNorm/Adam).
These are kept out of the default image: a separate
docker/Dockerfile.sglang.megatron layers them on top of Dockerfile.sglang,
and installation.md gains an optional "Step 5: Install the Megatron training
backend" under Option A. The FSDP backend and inference are unaffected.

Validated (exact bf16 match vs the HF reference checkpoint):
- Qwen3-0.6B     TP=2            310 tensors, 0 mismatch
- Qwen3-0.6B     PP=2            311 tensors, 0 mismatch
- Qwen3-0.6B     TP=2 PP=2       311 tensors, 0 mismatch
- Qwen3-30B-A3B  TP=2 EP=2 PP=2  18867 tensors, 0 mismatch

Replace the TP-only shard-direct weight transfer with the HF-export path:

- MegatronEngine.export_hf_named_params() / get_hf_weight_metadata() stream
  gathered HF tensors via mbridge (handles TP/PP/EP/ETP/VPP). The previous
  PP>1 / EP>1 NotImplementedError guards are removed.
- WeightManager gains "megatron_hf_meta" mode: the transfer buffer is sized
  for the full HF model and offload() streams HF tensors into the inactive
  half on the writer rank, while the gather collectives run on all ranks in
  lockstep. The sender receives megatron_metadata=None and runs the plain
  full/delta path used by FSDP. Because the buffer now holds HF-layout
  bytes, the sparse delta is computed in HF space and is correct under any
  parallelism — fixing the latent corruption where the delta was computed
  in mcore layout but applied by the receiver in HF layout.
- ppo_trainer wires the generator + HF metadata through.

The legacy CPU shard-reassembly in the sender agent is now unused for
Megatron (kept only for the deprecated megatron_metadata path).
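
Sizing the transfer buffer from metadata alone, without materializing any tensor, might look like the sketch below. It assumes the buffer is split into two equal halves (one active, one inactive) and that metadata arrives as `(name, shape, dtype_name)` tuples; `plan_layout`, `DTYPE_BYTES`, and the tensor names are hypothetical and not taken from WeightManager.

```python
from math import prod

# Bytes per element for the dtypes used in this example (assumed set).
DTYPE_BYTES = {"bfloat16": 2, "float32": 4}


def plan_layout(
    metadata: list[tuple[str, tuple[int, ...], str]],
) -> tuple[dict[str, tuple[int, int]], int]:
    """Assign each tensor a (byte_offset, nbytes) slot, packed back to back.

    Returns the per-tensor layout and the total bytes for one buffer half.
    """
    layout = {}
    cursor = 0
    for name, shape, dtype in metadata:
        nbytes = prod(shape) * DTYPE_BYTES[dtype]
        layout[name] = (cursor, nbytes)
        cursor += nbytes
    return layout, cursor


metadata = [
    ("embed.weight", (8, 4), "bfloat16"),
    ("norm.weight", (4,), "float32"),
]
layout, half_bytes = plan_layout(metadata)

# Double buffer: readers use the active half while the writer fills the other.
buffer = bytearray(2 * half_bytes)
active = 0
inactive_base = (1 - active) * half_bytes  # byte offset of the half to write
```

Since every rank can derive the same layout from the same metadata, offsets never need to be exchanged alongside the tensor bytes.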

Validated (buffer roundtrip == HF reference, bit-exact):
- Qwen3-0.6B TP=2       310 tensors, 0 mismatch, 1.19 GB
- Qwen3-0.6B TP=2 PP=2  311 tensors, 0 mismatch, 1.50 GB

Add examples/math/qwen3-8b-megatron-delta: the FSDP qwen3-8b-m2po-delta
recipe with the trainer engine switched to the Megatron backend
(backend: megatron, tensor_parallel_size: 4). Identical data, algorithm,
and weight-transfer path, so it doubles as a clean FSDP-vs-Megatron A/B.

End-to-end validation (single 8-GPU node, 4 RaaS + 4 trainer TP=4, delta
TCP weight sync, DeepScaleR/M2PO):
- Qwen3-8B (this recipe): 25 steps, 0 errors; weight_transfer/delta_sparsity
  ~0.92 (delta computed in HF space); task_reward/avg rose 0.535 (first half)
  -> 0.585 (last half), recent steps 0.61-0.64. Per-step weight offload 0.59s.
- Qwen3-30B-A3B MoE (TP2/PP3/EP2 trainer on 6 GPUs + SGLang TP2 on 2 GPUs):
  21 steps, 0 errors; full MoE export (18867 tensors, 61 GB) gathered across
  TP/PP/EP each step; task_reward/avg ~0.64 -> 0.66 (recent steps 0.70-0.77).

The Megatron HF-export offload materialized each gathered tensor in
pageable host memory via .to("cpu") before copying it into the pinned
shared-memory transfer buffer — a ~1 GB/s bounce that cost ~13s/step for
an 8B model on the RL critical path.

Now the engine yields the gathered tensors on GPU (export_hf_named_params
with to_cpu=False) and WeightManager copies each tensor's uint8 view directly
into the pinned buffer slice (non_blocking=True), fenced by a single
cuda.synchronize() before the cross-rank barrier. The pinned buffer is
already registered with cudaHostRegister, so the copy goes through the PCIe
DMA engine.

Copying through uint8 views on both sides keeps the copy alignment-free
(robust to mixed-dtype models) and byte-identical for contiguous sources.
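
The byte-view idea can be shown with a host-only analogy built on the standard library. This is not the CUDA path described above: there is no GPU, pinned memory, or `non_blocking` copy here, and `copy_as_bytes` is a hypothetical helper. It only demonstrates that once both sides are viewed as unsigned bytes, sources of different element types can be packed at arbitrary byte offsets.

```python
from array import array


def copy_as_bytes(dst: bytearray, offset: int, src) -> int:
    """Copy a contiguous buffer into `dst` at a byte offset via uint8 views.

    Returns the offset just past the copied bytes.
    """
    src_bytes = memoryview(src).cast("B")  # reinterpret elements as raw bytes
    end = offset + src_bytes.nbytes
    memoryview(dst)[offset:end] = src_bytes
    return end


halves = array("h", [1, 2, 3])    # 16-bit integers
floats = array("f", [1.5, -2.0])  # 32-bit floats

total = len(halves) * halves.itemsize + len(floats) * floats.itemsize
buf = bytearray(total)

# Pack back to back; the float data starts at an offset that is not a
# multiple of its element size, which is fine for a byte-wise copy.
cursor = copy_as_bytes(buf, 0, halves)
float_start = cursor
cursor = copy_as_bytes(buf, cursor, floats)

# Round-trip the float region to confirm the bytes are identical.
restored = array("f")
restored.frombytes(bytes(buf[float_start:cursor]))
```

After the round trip, `restored` holds the same values as `floats`, confirming that the byte-wise copy preserved the data exactly.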

Measured (Qwen3-8B, TP=4, 16.38 GB):
  pageable (old): 12.6s  (1.3 GB/s)
  direct DMA:      0.56s  (29.3 GB/s)   ~23x

Byte-equivalence verified (new buffer == old pageable path == HF
reference, bit-exact) across TP=2, PP=2, TP=2/PP=2, and MoE TP=2/EP=2/PP=2
(Qwen3-30B-A3B, 61 GB). Adds tests/test_direct_dma_offload.py (equivalence)
and tests/bench_offload_dma.py (throughput).
@jsw-zorro jsw-zorro requested a review from haizhongzheng as a code owner May 30, 2026 04:37