Muon torch submission#15
Merged
Merged
Conversation
|
MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅ |
…lgorithms into muon_torch
…lgorithms into muon_torch
Added finewebedu_lm.txt with model parameters for Muon and Adam.
15cd68e to
015f68d
Compare
priyakasimbeg
approved these changes
May 19, 2026
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to subscribe to this conversation on GitHub.
Already have an account?
Sign in.
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
New Submission: Muon
Submission Information
*credits to original Muon implementation from Keller Jordan
Evidence for the Submission's Performance
Muon original blogpost
Muon is Scalable for LLM Training
Modded-nanogpt
Submission Development
The Muon optimizer is primarily based on the approximate orthogonalization of weight matrices.
The employed Newton–Schulz algorithm, despite using only 5 iterations, introduces a significant slowdown compared to optimizers that update parameters independently.
For example, when updating an (N \times N) matrix in PyTorch on 1×A100-40GB, we observe that Muon can be up to ~30× slower than Adam:
Therefore, when using multiple devices, available compute should be allocated smartly by distributing the orthogonalization workload across devices, rather than simply replicating it.
To test the effectiveness of this paradigm, we compare two implementation on AlgoPerf:
MuonVanilla, where the optimization algorithm is replicated identically across devices: each rank orthogonalizes all parameters.MuonDataParallel: this design is based on the Muon GitHub repo. Each device locally updates a distinct subset of parameters, which are later all-gathered. This leverages PyTorch 2.5 support for all-gathering tensors of different shapes (docs).Compared to the original implementation, we sort parameters according to their Newton–Schulz computational complexity rather than parameter size.
We further replace the default PyTorch gradient all-reduce with a custom reduce-scatter. Since different devices update different parameter subsets, each device only requires the reduced gradients corresponding to the parameters it owns. Crucially, the scatter operation follows the block structure of the distributed Muon update. Finally, we make the all-gather operation asynchronous, enabling efficient overlap between communication and computation.
Notice that several more orthogonalization strategies are possible, and we give an overview of some of them in this diagram and in a dedicated wandb report.
Chossing a single implementaton:
We benchmark both Muon implementations (and AdamW for comparison) across AlgoPerf workloads (FineWebEdu workload was not avaialable at the time of this analysis) on 4×A100-40GB, using the batch sizes from the NadamW baseline submission.
Each run trains for 5% of the workload
step_hint, and the first 100 steps are excluded as burn-in.Distributing the orthogonalization burden across devices substantially reduces Muon overhead, bringing its runtime in line with AdamW. We observe a slow-down compared to the vanilla single-device version only on
crieto1tb, but a significant advantage onwmt.Submission Details
Backup optimizer.
We use AdamW as the backup optimizer, optimizing the following parameters with it:
wmt,criteo1tb,finewebedu_lmWe attach txt files of the resulting parameter split for each workloads for ease of inspection.
Momentum implementation.
We follow Adam-style EMA implementation of momentum, as also done in Muon official repo, and in modded-nanogpt. Notice, however, that the original formulation of Muon uses PyTorch-SGD-style momentum, and a similar implementation is followed by MoonShootAI.
3D,4D parameters.
3D parameters are flattened on the trailing dimensions and NS orthogonalization is applied. No 4D parameters are present in AlgoPerf at the time of this analysis.
Next steps
tests/).