Monolithic fused head-loss kernel for combinable losses + metrics #507

@jlamypoirier

Problem

Each language-model head loss/metric runs its own softmax pass over the full vocab. When losses are combined, the work is duplicated:

  • cross-entropy + z-loss: both compute sum_exp_logits over the same logits.
  • cross-entropy + distillation: both compute the student softmax (distillation additionally computes the teacher softmax).
  • GRPO + extra metrics (PR #494, "grpo: add policy-gradient metrics behind metrics enum"): fused_grpo_loss_forward_backward computes logits_norm, exp_logits, sum_exp_logits, and predicted_logits; compute_policy_gradient_metrics then recomputes all of them on the same logits. With compute_entropy_metric=True, a third softmax pass runs on top.
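
To make the first bullet concrete, here is a minimal, framework-free sketch for a single token's logits. It is not the repository's code: the function names are hypothetical, plain Python lists stand in for tensors, and z-loss is assumed to be the common squared log-partition form, (log Z)^2. The unfused functions each repeat the same pass; the shared version computes it once.

```python
import math

def softmax_stats(logits):
    """One pass over the vocabulary for a single token's logits."""
    logits_max = max(logits)
    logits_norm = [x - logits_max for x in logits]
    sum_exp_logits = sum(math.exp(x) for x in logits_norm)
    return logits_max, logits_norm, sum_exp_logits

def cross_entropy(logits, target):
    # Unfused: runs its own softmax pass.
    _, logits_norm, sum_exp_logits = softmax_stats(logits)
    return math.log(sum_exp_logits) - logits_norm[target]

def z_loss(logits):
    # Unfused: repeats the same pass just to recover log Z.
    # Assumed definition: z_loss = (log sum_i exp(logit_i)) ** 2.
    logits_max, _, sum_exp_logits = softmax_stats(logits)
    return (math.log(sum_exp_logits) + logits_max) ** 2

def cross_entropy_and_z_loss(logits, target):
    # Shared: a single pass feeds both scalars.
    logits_max, logits_norm, sum_exp_logits = softmax_stats(logits)
    log_sum_exp = math.log(sum_exp_logits)
    return log_sum_exp - logits_norm[target], (log_sum_exp + logits_max) ** 2
```

Both paths return the same values; the only difference is how many times the vocabulary is traversed.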

Each pass also issues its own tensor-parallel all-reduces on logits_max and sum_exp_logits, so the communication cost scales with the number of passes.
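
The communication side can be illustrated with a toy simulation. Everything here is an assumption made for illustration: SimulatedGroup is a hypothetical stand-in for a tensor-parallel process group that only counts collectives, the vocabulary is split across "ranks" as plain lists, and only the two all-reduces named above are modelled (a real implementation may issue more, for example for the target logit).

```python
import math

class SimulatedGroup:
    """Hypothetical process-group stand-in that counts all-reduce calls."""

    def __init__(self):
        self.all_reduce_calls = 0

    def all_reduce(self, per_rank_values, op):
        # Combine one value per rank, as an all-reduce would.
        self.all_reduce_calls += 1
        return op(per_rank_values)

def sharded_softmax_stats(shards, group):
    """Softmax statistics for logits whose vocabulary is split across ranks."""
    # All-reduce 1: global max for numerical stability.
    logits_max = group.all_reduce([max(s) for s in shards], max)
    # All-reduce 2: global sum of shifted exponentials.
    local_sums = [sum(math.exp(x - logits_max) for x in s) for s in shards]
    sum_exp_logits = group.all_reduce(local_sums, sum)
    return logits_max, sum_exp_logits

def count_collectives(shards, num_passes):
    """Number of all-reduces issued when num_passes losses each run their own pass."""
    group = SimulatedGroup()
    for _ in range(num_passes):
        sharded_softmax_stats(shards, group)
    return group.all_reduce_calls
```

Under these assumptions, count_collectives(shards, 3) issues three times as many all-reduces as count_collectives(shards, 1), which is the multiplication a single shared pass would avoid.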

@torch.compile does not fuse across separately decorated functions, so the redundant work costs both compute and memory.

Proposed direction

A single "monolithic" head-loss kernel (torch.compile and/or triton) that:

  • Takes a config / flag-set describing which losses and metrics to emit (CE, z-loss, distillation, GRPO clipped objective, GRPO ratio/KL/clamp/advantage stats, entropy, ...).
  • Runs softmax (and TP all-reduce) once over the logits.
  • Emits all requested scalars and the combined gradient in one pass.
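
As a rough illustration of the three points above, the following framework-free sketch handles one token's logits with a flag set, a single softmax pass, and a combined gradient. It is only a sketch under stated assumptions: HeadLossConfig, its fields, and fused_head_loss are hypothetical names, only CE, z-loss (assumed to be (log Z)^2), and an entropy metric are covered, and the tensor-parallel all-reduce and batching are omitted.

```python
import math
from dataclasses import dataclass

@dataclass
class HeadLossConfig:
    """Hypothetical flag set; not an existing config class."""
    cross_entropy_weight: float = 1.0
    z_loss_weight: float = 0.0
    compute_entropy_metric: bool = False

def fused_head_loss(logits, target, config):
    # Single softmax pass shared by every loss and metric below.
    logits_max = max(logits)
    logits_norm = [x - logits_max for x in logits]
    exp_logits = [math.exp(x) for x in logits_norm]
    sum_exp_logits = sum(exp_logits)
    log_sum_exp = math.log(sum_exp_logits)
    probs = [e / sum_exp_logits for e in exp_logits]

    outputs = {}
    grad = [0.0] * len(logits)  # gradient of the weighted sum of losses

    if config.cross_entropy_weight:
        outputs["cross_entropy"] = log_sum_exp - logits_norm[target]
        for i, p in enumerate(probs):
            one_hot = 1.0 if i == target else 0.0
            grad[i] += config.cross_entropy_weight * (p - one_hot)

    if config.z_loss_weight:
        log_z = log_sum_exp + logits_max
        outputs["z_loss"] = log_z ** 2
        for i, p in enumerate(probs):
            # d/dlogit_i (log Z)^2 = 2 * log Z * p_i
            grad[i] += config.z_loss_weight * 2.0 * log_z * p

    if config.compute_entropy_metric:
        # Metric only (no gradient): H = log_sum_exp - sum_i p_i * logits_norm_i
        outputs["entropy"] = log_sum_exp - sum(
            p * x for p, x in zip(probs, logits_norm)
        )

    return outputs, grad
```

A call such as fused_head_loss(logits, target, HeadLossConfig(1.0, 0.1, True)) returns the requested scalars in a dict along with one gradient list. Adding distillation or the GRPO terms would follow the same pattern of reusing probs and log_sum_exp.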

This subsumes fused_grpo_loss_forward_backward, _fused_cross_entropy_base_from_*, fused_softmax_base, compute_policy_gradient_metrics, and the entropy chunking in PR #494.

Out of scope

  • Implementation details (triton vs torch.compile, kernel layout): left for the design phase.
  • The current PRs that hit this redundancy (notably #494, "grpo: add policy-gradient metrics behind metrics enum") are not blocked on this issue and should land with their existing structure; this is the longer-term consolidation.

Motivation

Surfaced during review of #494, but the underlying redundancy predates it.
