Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file.
327 changes: 327 additions & 0 deletions submissions/self_tuning/muon_torch_jax_hps/muon.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,327 @@
""""
Muon PyTorch DDP implementation.

The ``MuonDataParallel`` implementation reduce-scatters the gradients internally, distributing the parameter updates across gpus.

The current implementation flattens trailing dimension for any parameter with more than 2 dimensions.
While this is correct for 4D matrices, it might be problematic for some batched 3D parameters.
The current implementation is correct for AlgoPerf v07.0 workloads, but might require adjustments in future.
In its current stage ``MuonDataParallel`` does not support standard gradient clippin.
"""

import os
import torch
import torch.distributed as dist
from abc import ABC, abstractmethod
import logging


# Distributed settings
USE_DDP = 'RANK' in os.environ
RANK = int(os.environ['RANK']) if USE_DDP else 0
LOCAL_RANK = int(os.environ['LOCAL_RANK']) if USE_DDP else 0
WORLD_SIZE = int(os.environ["WORLD_SIZE"]) if USE_DDP else 1

# Default values for Newton-Schulz
NS_A, NS_B, NS_C = 3.4445, -4.7750, 2.0315
NS_STEPS = 5
NS_EPS = 1e-7


@torch.compile()
def zeropower_via_newtonschulz5(G, steps=NS_STEPS, eps=NS_EPS):
"""
Newton-Schulz iteration to approximally orthogonalize G.
5-th order odd polynomial to approximate sign(x) on [-1,1],
pushing singlular values to {+1,-1}.

M = U @ S @ V.T
sign(M) = U @ sign(S) @ V.T, odd matrix polynomial commutes with SVD
sign(x) ~= a*x + b*x^3 + c*x^5, x in [-1,1]
"""
if G.ndim != 2:
raise RuntimeError(f"Expected 2D tensor in N-S, found {G.ndim} instead.")
a, b, c = NS_A, NS_B, NS_C
X = G.bfloat16()
if G.size(0) > G.size(1):
X = X.T

# Ensure spectral norm is at most 1.
# Ortho(cX)=Ortho(X), so we can normalize by ||X||_2 <= ||X||_F
X /= (X.norm() + eps)

# NS iterations
for _ in range(steps):
A = X @ X.T
B = b * A + c * (A @ A)
X = a * X + B @ X

if G.size(0) > G.size(1):
X = X.T
return X


@torch.no_grad()
def muon_update(g, m, beta, nesterov, ns_steps, ns_eps):
"""Updates momentum ``m`` in-place and returns Muon update."""
m.mul_(beta).add_(g, alpha=1 - beta)

if nesterov:
g = g.mul(1-beta).add(m, alpha=beta)
else:
g = m

if g.ndim == 4:
g = g.reshape(g.size(0), -1) # flatten trailing dims on 4D params
g = zeropower_via_newtonschulz5(g, steps=ns_steps, eps=ns_eps)
if m.ndim == 4:
g = g.view(m.shape) # restore original shape

return g


def _adjust_lr_to_match_adam(lr, param_shape):
# https://arxiv.org/pdf/2502.16982
A, B = param_shape[:2]
return lr * 0.2 * (max(A, B) ** 0.5)


def _adjust_lr_spectral_norm(lr, param_shape):
# Adjust from spectral norm 1 to RMS operator norm 1
# https://arxiv.org/abs/2310.17813
fan_out, fan_in = param_shape[:2]
return lr * max(1.0, (fan_out / fan_in) ** 0.5)


def _param_to_complexity(p: torch.Tensor) -> int:
"""Compute NS complexity on p.grad."""
# Shape after flatting potential trailing dims (3D, 4D)
m, n = (p.shape[0], torch.tensor(p.shape[1:]).prod().item())
# X @ X.T complexity: m^2n
# XX.T @ XX.T complexity: m^3
# XX.TXX.T @ X complexity: m^2n
return 2 * (m ** 2) * n + m ** 3


class MuonBase(torch.optim.Optimizer, ABC):
"""
Muon optimizer - Momentum Orthogonalized by Newton-Schulz.

Abstract class.
"""

def __init__(
self,
params,
lr=0.02,
weight_decay=0.0,
beta=0.95,
nesterov=True,
ns_steps=NS_STEPS,
ns_eps=NS_EPS,
adjust_lr=None,
):
if not 0.0 <= lr:
raise ValueError(f'Invalid learning rate: {lr}')
if not 0.0 <= weight_decay:
raise ValueError(f'Invalid weight_decay: {weight_decay}')
if not 0.0 <= beta < 1.0:
raise ValueError(f'Invalid muon_beta parameter: {beta}')
if nesterov not in [True, False]:
raise ValueError(f'Invalid nesterov parameter: {nesterov}')
if not 0 < ns_steps:
raise ValueError(f'Invalid ns_steps parameter: {ns_steps}')
if not 0.0 <= ns_eps:
raise ValueError(f'Invalid ns_eps parameter: {ns_eps}')
if not adjust_lr in [None, 'spectral_norm', 'match_adam']:
raise ValueError(f'Invalid adjust_lr parameter: {adjust_lr}')

defaults = dict(
lr = lr,
weight_decay = weight_decay,
beta = beta,
nesterov = nesterov,
ns_steps = ns_steps,
ns_eps = ns_eps,
)
super().__init__(params, defaults)

if adjust_lr is None:
self._adjust_lr = lambda lr, param_shape: lr
elif adjust_lr == 'spectral_norm':
self._adjust_lr = _adjust_lr_spectral_norm
elif adjust_lr == 'match_adam':
self._adjust_lr = _adjust_lr_to_match_adam

@abstractmethod
@torch.no_grad()
def step(self, closure=None):
pass


class MuonDataParallel(MuonBase):
"""
Distributed Data Parallel Muon Pytorch implementation.

Modified from: https://github.com/KellerJordan/Muon/blob/master/muon.py#L98

For each parameter group, (sorted) parameters are processed in blocks of world_size.
Each block is distributed round-robin across ranks;
each device updates its assigned parameters locally,
then all_gather syncs the block across ranks.

We sort parameters based on the corresponding Newton-Schultz complexity,
rather then based on thier size.

``step`` structure:
- ReduceScatter gradients round-robin
- Orthogonalize gradients locally, update param
- AllGather params round-robin

Both collective operations are asynchronous,
allowing to overlap computation and communication.
We wait on reduce-scatter when updating, and wait for the all-gather
ops to finish at the end of step.

Comms: one all-gather per block -> ~#params/WORLD_SIZE comms.
Space: O(largest_param)
"""
def __init__(self, params, **kwargs):
if not isinstance(params, list):
params = list(params)

if not dist.is_initialized():
raise ValueError('Using MuonDDP in a non-distributed run.')

# Sort params to fairly distribute orthogonalization across devices
if isinstance(params[0], dict): # sort each param group individually
for group in params:
group["params"] = sorted(group["params"], key=_param_to_complexity, reverse=True)
else:
params = sorted(params, key=_param_to_complexity, reverse=True)

super().__init__(params, **kwargs)


@torch.no_grad()
def step(self, closure=None):
"""
1. ReduceScatter: process grads round-robin, ReduceScatter each block. Work handles are stored.
2. AllGather: process params round-robin, wait for ReduceScatter handles on that block.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()

# 1. ReduceScatter grads
for group in self.param_groups:
group['reduce_handles'] = []

# References to grads, ensure valid tensors for reduce_scatter
grads = [p.grad if p.grad is not None else torch.zeros_like(p)
for p in group["params"]]

# Pad grads so each reduce_scatter block is of size WORLD_SIZE.
pad = (WORLD_SIZE - len(grads) % WORLD_SIZE) % WORLD_SIZE
grads_pad = grads + [torch.zeros_like(grads[-1])] * pad

# Iterate over grads in blocks of WORLD_SIZE
for block_start in range(0, len(grads), WORLD_SIZE):
# Skip padded tensor when reducing
if block_start + RANK < len(grads):
receiv = grads_pad[block_start + RANK] # ref to p.grad
else:
receiv = torch.zeros_like(grads_pad[block_start + RANK]) # dummy buffer

# ReduceScatter this block (avg)
with torch.no_grad():
handle = dist.reduce_scatter(
receiv,
grads_pad[block_start:block_start + WORLD_SIZE],
op=dist.ReduceOp.AVG,
async_op=True
)
group['reduce_handles'].append(handle)

# 2. Update and AllGather (overlapped)
gather_handles = []
for group in self.param_groups:
lr = group['lr']
wd = group['weight_decay']
beta = group['beta']
nesterov = group['nesterov']
ns_steps = group['ns_steps']
ns_eps = group['ns_eps']
params = group['params']
reduce_handles = group['reduce_handles']

# Pad params so each all-gather block is of size WORLD_SIZE.
# list concat keeps param refs (not copies), so all_gather updates model params directly.
pad = (WORLD_SIZE - len(params) % WORLD_SIZE) % WORLD_SIZE
params_pad = params + [torch.empty_like(params[-1])] * pad

# Iterate over params in blocks of WORLD_SIZE
for block_start in range(0, len(params), WORLD_SIZE):
# Wait for grads in this block.
reduce_handles.pop(0).wait()

# Each device updates the RANK-th tensor in the block
if block_start + RANK < len(params): # skip padded tensors
p = params[block_start + RANK] # round-robin
if p.grad is None:
p.grad = torch.zeros_like(p) # ensure valid tensor for all_gather

state = self.state[p]

if len(state) == 0:
state['m'] = torch.zeros_like(p)

g = muon_update(p.grad, state['m'], beta=beta, nesterov=nesterov, ns_steps=ns_steps, ns_eps=ns_eps)

adjusted_lr = self._adjust_lr(lr, p.shape) # optionally adjust lr
p.mul_(1 - lr * wd)
p.add_(g, alpha=-adjusted_lr)

# all-gather current block of params (including padded entries)
handle = dist.all_gather(
params_pad[block_start:block_start + WORLD_SIZE],
params_pad[block_start + RANK],
async_op=True
)
gather_handles.append(handle)

# Sync point
for handle in gather_handles:
handle.wait()

return loss


def split_params_muon_adam(model):
"""Split parameters:
- Muon: all matrix params (ndim ≥ 2) except embeddings
- Adam: 1D params, all embeddings
"""
muon_params, adam_params = [], []
muon_infos, adam_infos = [], [] # for logging purposes

for n, p in model.named_parameters():
if not p.requires_grad:
continue

# Assign embeddings to Adam (wmt, criteo1tb, finewebedu_lm)
if "embedding" in n.lower() or "embed_tokens" in n.lower():
adam_params.append(p)
adam_infos.append(f'{n} (ndim={p.ndim})')
elif p.ndim >= 2:
muon_params.append(p)
muon_infos.append(f'{n} (ndim={p.ndim})')
else:
adam_params.append(p)
adam_infos.append(f'{n} (ndim={p.ndim})')

logging.info("Muon params:\n\t" + "\n\t".join(muon_infos))
logging.info("Adam params:\n\t" + "\n\t".join(adam_infos))

return muon_params, adam_params
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
I1007 11:06:18.108439 23099690366016 utils.py:33] Muon params:
module.bot_mlp.0.weight (ndim=2)
module.bot_mlp.2.weight (ndim=2)
module.bot_mlp.4.weight (ndim=2)
module.top_mlp.0.weight (ndim=2)
module.top_mlp.2.weight (ndim=2)
module.top_mlp.4.weight (ndim=2)
module.top_mlp.6.weight (ndim=2)
module.top_mlp.9.weight (ndim=2)
I1007 11:06:18.108455 22916307100736 submission_runner.py:339] Initializing checkpoint and logger.
I1007 11:06:18.108510 23099690366016 utils.py:34] Adam params:
module.embedding_chunk_0 (ndim=2)
module.embedding_chunk_1 (ndim=2)
module.embedding_chunk_2 (ndim=2)
module.embedding_chunk_3 (ndim=2)
module.bot_mlp.0.bias (ndim=1)
module.bot_mlp.2.bias (ndim=1)
module.bot_mlp.4.bias (ndim=1)
module.top_mlp.0.bias (ndim=1)
module.top_mlp.2.bias (ndim=1)
module.top_mlp.4.bias (ndim=1)
module.top_mlp.6.bias (ndim=1)
module.top_mlp.9.bias (ndim=1)
Loading
Loading