Muon torch submission #15

Merged: priyakasimbeg merged 32 commits into mlcommons:main from Niccolo-Ajroldi:muon_torch on May 19, 2026
@Niccolo-Ajroldi Niccolo-Ajroldi commented Sep 30, 2025

New Submission: Muon

Submission Information

submission_name: "MuonTorch"
submission_folder: "submissions/self_tuning/muon_torch/"
authors: "Niccolò Ajroldi*"
affiliations: "ELLIS Institute Tübingen, Max Planck Institute for Intelligent Systems"
version: "1.0"
ruleset: "self-tuning"
framework: "PyTorch"
description: "Muon DDP implementation in PyTorch."

*Credits to the original Muon implementation by Keller Jordan.

Evidence for the Submission's Performance

Muon original blogpost
Muon is Scalable for LLM Training
Modded-nanogpt

Submission Development

The Muon optimizer is primarily based on approximately orthogonalizing the update of each weight matrix.
The employed Newton–Schulz algorithm, despite using only 5 iterations, introduces a significant slowdown compared to optimizers that update parameters independently.
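
For reference, a minimal dependency-free sketch of a quintic Newton–Schulz iteration is given below. The coefficients `(3.4445, -4.7750, 2.0315)` and the Frobenius-norm scaling are taken to match Keller Jordan's public Muon implementation and should be treated as assumptions here, not as a copy of this submission's code. The helper names (`transpose`, `matmul`, `newton_schulz5`) are hypothetical, and a real implementation would use GPU matmuls in PyTorch instead of nested lists.

```python
import math


def transpose(m):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(p, q):
    """Plain-Python matrix product p @ q."""
    q_cols = list(zip(*q))
    return [[sum(x * y for x, y in zip(row, col)) for col in q_cols] for row in p]


def newton_schulz5(g, steps=5, eps=1e-7):
    """Approximately orthogonalize g with a quintic Newton-Schulz iteration."""
    a, b, c = 3.4445, -4.7750, 2.0315  # assumed coefficients, see lead-in
    # Scale by the Frobenius norm so that all singular values are at most 1.
    norm = math.sqrt(sum(v * v for row in g for v in row)) + eps
    x = [[v / norm for v in row] for row in g]
    # Work on the wide orientation so the Gram matrix is as small as possible.
    tall = len(x) > len(x[0])
    if tall:
        x = transpose(x)
    for _ in range(steps):
        gram = matmul(x, transpose(x))   # A = X X^T
        gram2 = matmul(gram, gram)       # A^2
        poly = [[b * u + c * w for u, w in zip(r1, r2)]
                for r1, r2 in zip(gram, gram2)]  # B = b A + c A^2
        px = matmul(poly, x)             # B X
        x = [[a * u + w for u, w in zip(r1, r2)]
             for r1, r2 in zip(x, px)]   # X <- a X + B X
    return transpose(x) if tall else x
```

Each iteration performs three matrix products, which is the source of the overhead relative to element-wise optimizers. With these coefficients the singular values are pushed toward 1 only approximately, so `newton_schulz5([[3.0, 0.0], [0.0, 1.0]])` returns a diagonal matrix with entries near, but not equal to, 1.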

For example, when updating an N×N matrix in PyTorch on 1×A100-40GB, we observe that Muon can be up to ~30× slower than Adam:

+-------+---------+----------+----------+
|    N  | SGD (s) | Adam (s) | Muon (s) |
+-------+---------+----------+----------+
|    10 | 0.00116 |  0.00088 |  0.00662 |
|   100 | 0.00032 |  0.00044 |  0.00377 |
|  1000 | 0.00034 |  0.00045 |  0.00990 |
| 10000 | 0.00201 |  0.00480 |  0.14230 |
+-------+---------+----------+----------+

Therefore, when using multiple devices, available compute should be allocated smartly by distributing the orthogonalization workload across devices, rather than simply replicating it.
To test the effectiveness of this paradigm, we compare two implementations on AlgoPerf:

  • MuonVanilla, where the optimization algorithm is replicated identically across devices: each rank orthogonalizes all parameters.
  • MuonDataParallel: this design is based on the Muon GitHub repo. Each device locally updates a distinct subset of parameters, which are later all-gathered. This leverages PyTorch 2.5 support for all-gathering tensors of different shapes (docs).
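
The difference between the two designs above can be illustrated with a single-process simulation. This is only a sketch under stated assumptions: it does not use `torch.distributed`, the round-robin ownership rule and the function names (`local_update`, `vanilla_step`, `data_parallel_step`) are hypothetical, and `local_update` is a stand-in for the expensive per-parameter orthogonalized update.

```python
def local_update(value):
    """Stand-in for an expensive per-parameter update (e.g. orthogonalization)."""
    return value * 0.5


def vanilla_step(params, world_size):
    """Every simulated rank redundantly updates every parameter."""
    per_rank = [{name: local_update(v) for name, v in params.items()}
                for _ in range(world_size)]
    work = [len(params)] * world_size  # updates performed by each rank
    return per_rank[0], work


def data_parallel_step(params, world_size):
    """Each simulated rank updates only the parameters it owns, then all ranks gather."""
    names = sorted(params)
    # Hypothetical round-robin ownership: rank r owns names[r::world_size].
    owned = [names[rank::world_size] for rank in range(world_size)]
    shards = [{n: local_update(params[n]) for n in owned[rank]}
              for rank in range(world_size)]
    # Simulated all-gather: every rank ends up with the union of all shards.
    gathered = {}
    for shard in shards:
        gathered.update(shard)
    work = [len(shard) for shard in shards]
    return gathered, work
```

Both functions return the same updated parameters, but with 8 parameters on 4 simulated ranks the per-rank work drops from 8 updates to 2, at the price of one extra gather.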

Compared to the original implementation, we sort parameters according to their Newton–Schulz computational complexity rather than parameter size.
We further replace the default PyTorch gradient all-reduce with a custom reduce-scatter. Since different devices update different parameter subsets, each device only requires the reduced gradients corresponding to the parameters it owns. Crucially, the scatter operation follows the block structure of the distributed Muon update. Finally, we make the all-gather operation asynchronous, enabling efficient overlap between communication and computation.
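
To see why ordering by Newton–Schulz cost can differ from ordering by parameter size, consider the rough sketch below. The cost model (three matrix products per iteration on the wide orientation of the matrix) and the greedy least-loaded assignment are assumptions made for illustration; the names `ns_cost` and `assign_by_cost` are hypothetical and do not refer to functions in the submission.

```python
import heapq


def ns_cost(shape, steps=5):
    """Rough multiply-add count of Newton-Schulz on a parameter of this shape."""
    rows = shape[0]
    cols = 1
    for d in shape[1:]:  # trailing dimensions are flattened into columns
        cols *= d
    m, n = min(rows, cols), max(rows, cols)
    # Per iteration: X @ X^T (m*m*n), A @ A (m*m*m), B @ X (m*m*n).
    return steps * (2 * m * m * n + m ** 3)


def assign_by_cost(shapes, world_size):
    """Greedily give the most expensive remaining parameter to the least-loaded rank."""
    heap = [(0, rank) for rank in range(world_size)]
    heapq.heapify(heap)
    owners = {}
    for name in sorted(shapes, key=lambda k: ns_cost(shapes[k]), reverse=True):
        load, rank = heapq.heappop(heap)
        owners[name] = rank
        heapq.heappush(heap, (load + ns_cost(shapes[name]), rank))
    return owners
```

Under this cost model a 1024×1024 matrix and a 64×16384 matrix hold the same number of elements, yet the square one is estimated to be more than 20 times as expensive to orthogonalize, so a size-based ordering would treat them as equal while a cost-based ordering would not.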

Note that several other orthogonalization strategies are possible; we give an overview of some of them in this diagram and in a dedicated wandb report.

Choosing a single implementation:

We benchmark both Muon implementations (and AdamW for comparison) across AlgoPerf workloads (the FineWebEdu workload was not available at the time of this analysis) on 4×A100-40GB, using the batch sizes from the NadamW baseline submission.
Each run trains for 5% of the workload step_hint, and the first 100 steps are excluded as burn-in.

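+------------------+---------------------------------+-----------------------------+
| optim_name       | accumulated_submission_time_min | time_saved_over_vanilla_min |
+------------------+---------------------------------+-----------------------------+
| MuonVanilla      |                          352.18 |                        0.00 |
| MuonDataParallel |                          334.76 |                       17.42 |
| AdamW            |                          337.06 |                       15.12 |
+------------------+---------------------------------+-----------------------------+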

Distributing the orthogonalization burden across devices substantially reduces Muon overhead, bringing its runtime in line with AdamW. We observe a slowdown compared to the vanilla single-device version only on criteo1tb, but a significant advantage on wmt.

Submission Details

  • Backup optimizer.
    We use AdamW as the backup optimizer, optimizing the following parameters with it:

    • 1D params (biases, layernorm, batchnorm)
    • Embeddings of wmt, criteo1tb, finewebedu_lm

    We attach txt files of the resulting parameter split for each workload for ease of inspection.

  • Momentum implementation.
    We follow the Adam-style EMA implementation of momentum, as done in the official Muon repo and in modded-nanogpt. Note, however, that the original formulation of Muon uses PyTorch-SGD-style momentum, and a similar implementation is followed by Moonshot AI.

  • 3D, 4D parameters.
    3D parameters are flattened on the trailing dimensions and NS orthogonalization is applied. No 4D parameters are present in AlgoPerf at the time of this analysis.
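
The two momentum conventions mentioned in the list above can be contrasted with a small scalar sketch. It assumes a zero-initialized buffer, no dampening, and no Nesterov correction; the function names `ema_momentum` and `sgd_momentum` are hypothetical and the code is not taken from the submission.

```python
def ema_momentum(grads, beta=0.95):
    """Adam-style exponential moving average: buf = beta * buf + (1 - beta) * g."""
    buf, history = 0.0, []
    for g in grads:
        buf = beta * buf + (1.0 - beta) * g
        history.append(buf)
    return history


def sgd_momentum(grads, beta=0.95):
    """PyTorch-SGD-style momentum without dampening: buf = beta * buf + g."""
    buf, history = 0.0, []
    for g in grads:
        buf = beta * buf + g
        history.append(buf)
    return history
```

Under these assumptions the two buffers differ only by the constant factor `1 - beta` at every step. Because Newton–Schulz orthogonalization typically rescales its input before iterating, such a constant factor would largely cancel in the orthogonalized update, although it still matters wherever the raw buffer is used without normalization.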

Next steps

  • Efficient DP implementation
  • Momentum with dampening
  • Update with new dropout
  • Identify which layers to optimize with Muon and which with AdamW.
  • Test equivalence on toy problem (tests/).
  • Test equivalence on AlgoPerf workloads (deterministic + fixed eval_every_steps)
  • Compare speed across implementations
    • Is the efficient DP version faster?
    • Is it worth it to manually ReduceScatter gradients?
  • Decide on a single implementation to score.

@Niccolo-Ajroldi Niccolo-Ajroldi requested a review from a team as a code owner September 30, 2025 18:54
github-actions Bot commented Sep 30, 2025

MLCommons CLA bot All contributors have signed the MLCommons CLA ✍️ ✅

@Niccolo-Ajroldi changed the title from "muon torch vanilla DP" to "Muon torch submission (vanilla DP)" Sep 30, 2025
@Niccolo-Ajroldi changed the title from "Muon torch submission (vanilla DP)" to "Muon torch submission" Oct 6, 2025
@Niccolo-Ajroldi changed the title from "Muon torch submission" to "Muon torch submission [WIP]" Oct 6, 2025
@Niccolo-Ajroldi changed the title from "Muon torch submission [WIP]" to "Muon torch submission" May 18, 2026
@priyakasimbeg priyakasimbeg merged commit eaa6195 into mlcommons:main May 19, 2026
9 checks passed
@github-actions github-actions Bot locked and limited conversation to collaborators May 19, 2026