Problem
Each language-model head loss/metric runs its own softmax pass over the full vocab. When losses are combined, the work is duplicated:
cross-entropy + z-loss — both compute sum_exp_logits over the same logits.
cross-entropy + distillation — both compute the student softmax (and additionally the teacher softmax for distillation).
GRPO + extra metrics (PR #494, "grpo: add policy-gradient metrics behind metrics enum") — fused_grpo_loss_forward_backward computes logits_norm, exp_logits, sum_exp_logits, and predicted_logits; compute_policy_gradient_metrics then recomputes all of them on the same logits. With compute_entropy_metric=True, a third softmax pass runs on top.
Each pass also issues its own tensor-parallel all-reduces on logits_max / sum_exp_logits, multiplying the comm cost.
@torch.compile does not fuse across separate decorated functions, so the redundant work is real both in compute and memory.
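A minimal sketch of the duplication, using plain-PyTorch stand-ins rather than the actual chunked/compiled kernels (the shapes and helper names here are illustrative only):

```python
import torch

# Plain-PyTorch stand-ins, only to show the repeated full-vocab reductions;
# the real loss paths are chunked and @torch.compile-decorated, but the
# redundant work is the same.

def cross_entropy(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    # logsumexp = logits_max + log(sum_exp_logits): one full-vocab reduction.
    lse = torch.logsumexp(logits, dim=-1)
    picked = logits.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    return (lse - picked).mean()

def z_loss(logits: torch.Tensor) -> torch.Tensor:
    # Recomputes the exact same logsumexp over the same logits.
    lse = torch.logsumexp(logits, dim=-1)
    return (lse ** 2).mean()

logits = torch.randn(8, 32000, requires_grad=True)   # (tokens, vocab)
targets = torch.randint(0, 32000, (8,))
loss = cross_entropy(logits, targets) + 0.01 * z_loss(logits)
# Two independent logsumexp/softmax passes (and, under tensor parallelism,
# two all-reduces on logits_max / sum_exp_logits) for quantities that could
# be computed once and shared.
```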
Proposed direction
A single "monolithic" head-loss kernel (torch.compile and/or triton) that:
Takes a config / flag-set describing which losses and metrics to emit (CE, z-loss, distillation, GRPO clipped objective, GRPO ratio/KL/clamp/advantage stats, entropy, ...).
Runs softmax (and TP all-reduce) once over the logits.
Emits all requested scalars and the combined gradient in one pass.
This subsumes fused_grpo_loss_forward_backward, _fused_cross_entropy_base_from_*, fused_softmax_base, compute_policy_gradient_metrics, and the entropy chunking in PR #494.
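A minimal sketch of the shape this could take, in plain PyTorch. HeadLossConfig and fused_head_loss are hypothetical names used only for illustration; a real kernel would additionally chunk over the vocab, perform the TP all-reduces where noted, and emit the combined gradient in the same pass.

```python
from dataclasses import dataclass
import torch

@dataclass
class HeadLossConfig:
    # Hypothetical flag-set describing which losses/metrics to emit.
    cross_entropy: bool = True
    z_loss_weight: float = 0.0
    entropy_metric: bool = False

def fused_head_loss(logits: torch.Tensor, targets: torch.Tensor, cfg: HeadLossConfig) -> dict:
    # Softmax statistics are computed exactly once and shared by every output.
    logits_max = logits.max(dim=-1, keepdim=True).values   # TP all-reduce (max) would go here
    exp_logits = torch.exp(logits - logits_max)
    sum_exp = exp_logits.sum(dim=-1)                        # TP all-reduce (sum) would go here
    lse = logits_max.squeeze(-1) + torch.log(sum_exp)       # shared logsumexp

    out = {}
    if cfg.cross_entropy:
        picked = logits.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
        out["cross_entropy"] = (lse - picked).mean()
    if cfg.z_loss_weight:
        out["z_loss"] = cfg.z_loss_weight * (lse ** 2).mean()
    if cfg.entropy_metric:
        probs = exp_logits / sum_exp.unsqueeze(-1)
        out["entropy"] = (lse - (probs * logits).sum(dim=-1)).mean()
    # GRPO clipped objective and ratio/KL/clamp/advantage stats would branch
    # off the same exp_logits / sum_exp here instead of re-running softmax.
    return out
```

The point of the sketch is only that every requested scalar reads from one set of softmax statistics; the actual kernel layout is deferred to the design phase as noted below.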
Out of scope
Implementation details (triton vs torch.compile, kernel layout) — left for the design phase.
Motivation
Surfaced during review of #494, but the underlying redundancy predates it.