feat: add FlashQLA backend for Qwen GDN and skip selected comm memory checks #1947

Merged: zhuzilin merged 9 commits into THUDM:main from hxy771126-design:feat/qwen-gdn-flashqla-comm-memory on May 28, 2026
hxy771126-design commented May 26, 2026

Summary

This PR adds FlashQLA support for Qwen Gated DeltaNet layers and reduces distributed communication overhead by skipping the pre-communication memory guard for selected ops.

Changes included:

  • Add --qwen-gdn-backend {fla,flashqla} for Qwen3.5 / Qwen3-Next Megatron model plugins and keep fla as the default backend.
  • Load FlashQLA only when --qwen-gdn-backend flashqla is selected, and fail early with a clear error if the runtime is unsupported.
  • Install FlashQLA by default in the main Docker / conda setup paths, with GB10 kept opt-in because it needs separate validation.
  • Skip the pre-communication available_memory() check for the following selected communication ops:
    • all_gather_into_tensor
    • allgather_into_tensor_coalesced
    • barrier
    • broadcast_object_list
    • reduce_scatter_tensor
    • all_to_all_single
    • isend
    • irecv
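
The backend switch and lazy loading described above can be sketched as follows. This is a minimal illustration, not the code from this PR: the helper names `add_gdn_backend_arg` and `load_gdn_backend`, the `BACKEND_MODULES` mapping, and the module name `flash_qla` are all assumptions made for the example. Only the flag name, its two choices, and the `fla` default come from the description.

```python
import argparse
import importlib

GDN_BACKENDS = ("fla", "flashqla")

# Hypothetical mapping from flag value to importable module name.
# "flash_qla" is an assumed module name used only for illustration.
BACKEND_MODULES = {"fla": "fla", "flashqla": "flash_qla"}


def add_gdn_backend_arg(parser):
    """Register the backend flag, keeping fla as the default."""
    parser.add_argument(
        "--qwen-gdn-backend",
        choices=GDN_BACKENDS,
        default="fla",
        help="Kernel backend for Qwen Gated DeltaNet layers.",
    )
    return parser


def load_gdn_backend(name, importer=importlib.import_module):
    """Import the selected backend only when it is requested.

    Raises RuntimeError with a readable message if the import fails,
    so an unsupported runtime is reported before training starts.
    """
    if name not in BACKEND_MODULES:
        raise ValueError(f"unknown GDN backend: {name!r}")
    module_name = BACKEND_MODULES[name]
    try:
        return importer(module_name)
    except ImportError as exc:
        raise RuntimeError(
            f"--qwen-gdn-backend {name} was requested, but module "
            f"{module_name!r} could not be imported in this runtime"
        ) from exc
```

Because the import happens inside `load_gdn_backend`, a run that stays on the default `fla` backend never touches the FlashQLA package.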

Motivation

This PR has two goals.

First, Qwen Gated DeltaNet layers currently call the FLA chunk_gated_delta_rule directly. FlashQLA provides an optimized compatible implementation, so this PR adds a --qwen-gdn-backend {fla,flashqla} switch while keeping fla as the default backend.

Second, the current communication memory guard checks CUDA driver-visible free memory before every wrapped distributed call. The guard is useful as an OOM fallback when a communication path may need additional driver-visible workspace, but it also sits on a very hot path. In communication-heavy workloads, repeatedly calling cudaMemGetInfo before every wrapped collective can add noticeable overhead. This is related to the regression reported in #1717. The guard originally came from the final grad reduce-scatter workaround in #1133 and was later moved into the general low-level distributed wrapper in #1513.

This PR does not remove the communication memory guard globally. It keeps the guard for ops that showed clear memory-pressure signals, such as all_reduce, _allgather_base, _reduce_scatter_base, all_gather_object, broadcast, and reduce_scatter_tensor_coalesced.
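
A minimal sketch of the selective guard, assuming a wrapper that receives the op name, the underlying communication function, and a memory-check callable. The names `SKIP_MEMORY_CHECK_OPS` and `guard_comm_op` are hypothetical and do not reflect the actual wrapper in the repository; only the list of skipped op names is taken from this description.

```python
import functools

# Ops for which the pre-communication memory check is skipped (from the PR summary).
SKIP_MEMORY_CHECK_OPS = frozenset(
    {
        "all_gather_into_tensor",
        "allgather_into_tensor_coalesced",
        "barrier",
        "broadcast_object_list",
        "reduce_scatter_tensor",
        "all_to_all_single",
        "isend",
        "irecv",
    }
)


def guard_comm_op(op_name, fn, memory_check):
    """Return fn unchanged for skipped ops, otherwise run memory_check first.

    memory_check stands in for the guard that queries driver-visible free
    memory; it is passed in so the sketch stays independent of CUDA.
    """
    if op_name in SKIP_MEMORY_CHECK_OPS:
        return fn

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        memory_check()
        return fn(*args, **kwargs)

    return wrapped
```

Deciding at wrap time rather than inside the wrapper means the skipped ops pay no extra Python call per invocation, while guarded ops such as `all_reduce` keep the original behavior.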

For the skipped ops, we profiled communication memory behavior under three training settings: fixed batch, dynamic batch, and Flux + DeepEP. The profiling tracked both PyTorch allocator changes and CUDA driver-visible free-memory drops.
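
One way to collect both signals around a single call is sketched below. This is an illustrative harness, not the profiler used for this PR: `MemSample`, `cuda_probe`, and `profile_call` are hypothetical names. The real probe relies on `torch.cuda.mem_get_info()`, `torch.cuda.memory_allocated()`, and `torch.cuda.memory_reserved()`, and is injectable so the delta logic can run without a GPU.

```python
from dataclasses import dataclass

MIB = 1024 * 1024


@dataclass(frozen=True)
class MemSample:
    driver_free: int      # bytes free as seen by the CUDA driver
    torch_allocated: int  # bytes held by live tensors in the PyTorch allocator
    torch_reserved: int   # bytes reserved by the PyTorch caching allocator


def cuda_probe():
    """Sample memory on the current CUDA device (requires PyTorch with CUDA)."""
    import torch

    free_bytes, _total_bytes = torch.cuda.mem_get_info()
    return MemSample(
        driver_free=free_bytes,
        torch_allocated=torch.cuda.memory_allocated(),
        torch_reserved=torch.cuda.memory_reserved(),
    )


def profile_call(fn, probe=cuda_probe):
    """Run fn once and report memory deltas in MiB around the call."""
    before = probe()
    result = fn()
    after = probe()
    deltas = {
        "driver_free_drop_mib": (before.driver_free - after.driver_free) / MIB,
        "allocated_delta_mib": (after.torch_allocated - before.torch_allocated) / MIB,
        "reserved_delta_mib": (after.torch_reserved - before.torch_reserved) / MIB,
    }
    return result, deltas
```

A driver free-memory drop with zero allocator delta suggests memory taken outside the PyTorch allocator, for example by communication-library initialization. For asynchronous collectives, a real harness would likely also need to synchronize the device before the second sample.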

The selected ops fall into two categories:

  1. Ops that do not introduce observable extra GPU memory usage in steady-state training.
  2. Ops whose visible memory drop is limited to first-use / early communication initialization, not steady-state per-call allocation.

| Op | Key profiling signal | Observed memory behavior |
| --- | --- | --- |
| all_gather_into_tensor | Frequently called in all three profiled settings; max driver free-memory drop is only 6 MiB; PyTorch allocated/reserved delta stays at 0 MiB. | No meaningful extra GPU memory usage is observed. |
| allgather_into_tensor_coalesced | Rarely called in all three profiled settings; max driver free-memory drop is 0.25 MiB; PyTorch allocated/reserved delta stays at 0 MiB. | The coalesced tensor payload does not create additional steady-state GPU memory pressure. |
| barrier | Rarely called in all three profiled settings; driver free-memory drop is 0 MiB; PyTorch allocated/reserved delta stays at 0 MiB. | No GPU memory allocation is observed. |
| broadcast_object_list | Rarely called in fixed/dynamic batch and never called in Flux + DeepEP; driver free-memory drop is 0 MiB in the fixed/dynamic runs. | No GPU memory pressure is observed for the profiled object-list broadcasts. |
| reduce_scatter_tensor | Frequently called only in fixed batch; driver free-memory drop is 0 MiB in dynamic batch and Flux + DeepEP; fixed batch shows only a small one-off max drop of about 70 MiB; PyTorch allocated/reserved delta stays at 0 MiB. | No steady-state memory growth is observed. |
| all_to_all_single | Frequently called in fixed/dynamic batch and never called in Flux + DeepEP; the large visible drop of about 3.7 GiB appears at the first MoE all-to-all stage while the PyTorch allocated/reserved delta stays at 0 MiB. | This points to NCCL / ProcessGroup / all-to-all lazy initialization, not steady-state per-call tensor allocation. |
| isend / irecv | Rarely called in all three settings; profiling shows about 0 MiB of driver free-memory drop and about 0 MiB of PyTorch allocated/reserved delta. | No meaningful extra GPU memory usage is observed. |

For these ops, running available_memory() before every call does not help steady-state memory reclamation, but it adds overhead on the communication hot path.

End-to-end RL benchmark: Qwen3.5-35B-A3B

Main setup:

| Item | Value |
| --- | --- |
| Model | Qwen3.5-35B-A3B |
| Dataset | dapo-math-17k/dapo-math-17k.jsonl |
| Nodes / GPUs | 1 node, 8 GPUs (L20Z) |
| RL algorithm | GRPO |
| Reward model | deepscaler |
| Global batch size | 128 |
| Number of rollout iterations | 5 |
| Rollout batch size | 16 |
| Samples per prompt | 8 |
| Max response length | 8192 |
| Dynamic batch size | enabled |
| Max tokens per GPU | 8192 |
| TP / PP / CP / EP / ETP | 2 / 1 / 4 / 8 / 1 |
| Sequence parallel | enabled |
| Recompute | full / uniform / 1 layer |
| Optimizer | Adam, lr=1e-6, CPU offload enabled |
| Rollout GPUs per engine | 4 |
| SGLang static memory fraction | 0.5 |
| SGLang max running requests | 128 |

Results are averaged over the collected training / rollout metric records in each run:

| Run | Avg rollout tokens/GPU/s | Avg ref log probs time | Avg actor train time | Avg train time | Avg actor train tokens/s | Avg actor train TFLOPS | Train speedup vs baseline |
| --- | --- | --- | --- | --- | --- | --- | --- |
| Baseline: FLA backend + original comm memory check | 1203.7 | 449.2s | 1060.9s | 1510.7s | 989.3 | 2.23 | 1.00x |
| Selected comm memory-check skip only, FLA backend | 1244.8 | 203.1s | 516.9s | 720.4s | 2095.6 | 4.72 | 2.10x |
| FlashQLA backend only, original comm memory check | 1232.5 | 380.4s | 952.5s | 1333.4s | 1105.6 | 2.49 | 1.13x |
| This PR: FlashQLA backend + selected comm memory-check skip | 1268.1 | 186.5s | 371.1s | 558.1s | 2911.2 | 6.56 | 2.71x |

Note: FlashQLA provides a smaller end-to-end gain because its GDN rule-kernel speedup is diluted by larger full-step costs such as MoE dispatch/combine, NCCL communication, optimizer offload, and rollout/training synchronization.
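
The speedup column in the results table is consistent with dividing the baseline average train time by each run's average train time. The short check below reproduces it from the reported numbers; the dictionary keys and the function name are illustrative labels, not identifiers from the benchmark scripts.

```python
# Average train time in seconds, copied from the results table.
avg_train_time_s = {
    "baseline": 1510.7,
    "comm_skip_only": 720.4,
    "flashqla_only": 1333.4,
    "this_pr": 558.1,
}


def speedup_vs_baseline(times, baseline_key="baseline"):
    """Return baseline_time / run_time for every run, rounded to 2 decimals."""
    base = times[baseline_key]
    return {run: round(base / seconds, 2) for run, seconds in times.items()}


speedups = speedup_vs_baseline(avg_train_time_s)
for run, value in speedups.items():
    print(f"{run}: {value:.2f}x")
```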

@hxy771126-design

@huang3eng Please help to review.

huang3eng commented May 26, 2026

@hxy771126-design
This PR seems to include two separate optimizations:

  1. Skipping the pre-communication memory check for selected low-risk communication operators.
  2. Using FlashQLA instead of FLA for Qwen GDN.

It would be helpful to clarify the following two points:

Could you please provide the concrete performance benefit of each optimization separately?

Also, the current PR is still missing a clear definition of “low-risk communication operators”. Could you please add evidence explaining why it is safe to skip the memory check before these ops? For example, experiments showing that these operators do not allocate additional GPU memory, or that there is no steady-state CUDA driver-visible memory increase before/after these calls.

@huang3eng

LGTM

@zhuzilin zhuzilin merged commit b5994e6 into THUDM:main May 28, 2026
55 of 84 checks passed