
grpo: add policy-gradient metrics behind metrics enum #494

Open
bigximik wants to merge 8 commits into main from grpo-metrics
Conversation


@bigximik bigximik commented Apr 27, 2026

Summary

  • Adds GRPO training metrics on LanguageModelGRPOLossConfig, opt-in via a single metrics enum (none | basic | with_entropy).
  • basic logs: old_logprobs, ratio_new_old, ratio_new_old_sum, ratio_new_old_squared_sum, kl_new_old, clipped_ratio_fraction, advantage, max_advantage, min_advantage, num_tokens. with_entropy additionally logs per-token entropy.
  • Implementation lives in fast_llm/layers/language_model/loss/grpo.py. compute_grpo_metrics is a single @torch.compile function returning a GRPOMetrics NamedTuple; entropy reuses fused_softmax_base's already-computed exp_logits / logits_norm / sum_exp_logits (no second softmax pass).
  • Correct under TP vocab-parallel: the entropy dot-product term (exp_logits * logits_norm).sum(-1) is all-reduced over the TP group before dividing by the global sum_exp_logits, matching the pattern in _fused_cross_entropy_base_from_distribution (a short sketch follows this list).
  • Pipeline parallelism: the MAX/MIN advantage reductions are unsafe when ranks without this loss layer contribute the LossDef.reduce zero placeholder, so LanguageModelGRPOLoss.__init__ rejects metrics != none with pipeline_parallel > 1.
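
For illustration, a minimal sketch of that TP-correct entropy term, assuming the three tensors are exactly what fused_softmax_base already computed for the local vocab shard (helper and argument names are illustrative, not the PR's verbatim code):

import torch
import torch.distributed as dist

def grpo_entropy_sketch(logits_norm, exp_logits, sum_exp_logits, group=None):
    # logits_norm / exp_logits / sum_exp_logits are assumed to come straight
    # from fused_softmax_base for the local vocab slice.
    dot = (exp_logits * logits_norm).sum(-1)
    if group is not None:
        # Vocab-parallel TP: the dot-product term only covers the local slice,
        # so sum it across ranks before dividing by the global normalizer.
        dist.all_reduce(dot, group=group)
    # H = log(sum_j exp(z_j)) - E_p[z]; invariant to the max-shift applied in z.
    return sum_exp_logits.log() - dot / sum_exp_logits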

Omitted metrics

The following DeepSpeed-side metrics cannot be added because the required inputs are not available in Fast-LLM's GRPO training path:

  • reward — Fast-LLM receives only pre-computed advantages; raw per-sample rewards are never passed to the trainer.
  • ref-policy logprobs / ratio_ref / kl(new||ref) — GRPO has no reference model; these only exist in PPO-with-KL-penalty variants.
  • kl_coef, entropy_bonus_coef — scalar config constants, not per-token metrics.

Enabling

fast_llm:
  model:
    base_model:
      head:
        losses:
          grpo:
            metrics: basic         # or "with_entropy" / "none"

Test plan

  • pytest tests/layers/test_lm_losses.py — all test_grpo_metrics (single-rank) and test_lm_loss_distributed[grpo_metrics-*] (TP=2 over gloo) cases pass on CPU.
  • End-to-end: enable on a 4-node math run, verify metrics appear in W&B at step 1 with ratio_new_old/num_tokens ≈ 1, kl_new_old ≈ 0, clipped_ratio_fraction ≈ 0.

Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum,
ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage,
num_tokens, and optional per-token entropy.

New files:
- fast_llm/layers/language_model/loss/pg_metrics.py: reusable
  PolicyGradientMetrics dataclass + compute_policy_gradient_metrics()
  (callable by future PPO), with chunked vocab-parallel entropy support.
- tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq,
  packed multi-seq, masked tokens, clamp fraction, entropy correctness,
  mock SDP correctness, mock vocab-parallel entropy, normalization parity.

Config additions to LanguageModelGRPOLossConfig:
- compute_extra_metrics (default False): log all non-entropy metrics
- compute_entropy_metric (default False): additionally log per-token entropy
- entropy_chunk_size (default 4096): batch chunk size for entropy pass

Normalization matches existing new_logprobs_mean: sum(v*mask/label_counts)
then divided by num_documents_in_batch. MAX/MIN use LossDef ReductionType
and correct ReduceOp so they aggregate correctly across microbatches and
SDP/sequence-parallel ranks.
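
In other words, each logged scalar is formed roughly as follows (a sketch; tensor names are illustrative, with label_counts broadcast to each token's document length):

import torch

def masked_document_mean(
    values: torch.Tensor, loss_mask: torch.Tensor, label_counts: torch.Tensor, num_documents: int
) -> torch.Tensor:
    # sum(v * mask / label_counts), then divide by the number of documents in
    # the batch, matching the existing new_logprobs_mean normalization.
    return (values * loss_mask / label_counts).sum() / num_documents
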
Rename four metrics to match DeepSpeed's naming exactly so runs on both
backends produce comparable WandB keys:

  ratio        → ratio_new_old
  ratio_sum    → ratio_new_old_sum
  ratio_sq_sum → ratio_new_old_squared_sum
  clamp_frac   → clamp_log_ratio_new_old_indicator
@bigximik bigximik requested a review from jlamypoirier April 29, 2026 08:04

@jlamypoirier jlamypoirier left a comment


Review with Claude. I'll implement the changes myself

PR #494 Review (v2) — grpo: add policy-gradient metrics behind compute_extra_metrics flag

Overview

Adds 10 opt-in GRPO training metrics (ratio stats, Schulman KL, advantage min/max/mean, clamp fraction, token
count) plus optional per-token entropy, gated by config flags on LanguageModelGRPOLossConfig. New pg_metrics.py
module + 8 unit tests on CPU. The math (Schulman k3 KL, masked-mean normalization, vocab-parallel softmax) is
sound.
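
For reference, the Schulman k3 estimator mentioned here, as a sketch (variable names and the exact ratio direction are assumptions; the PR's sign conventions may differ):

import torch

def k3_kl_new_old(new_logprobs: torch.Tensor, old_logprobs: torch.Tensor) -> torch.Tensor:
    # Schulman's k3 estimator: (r - 1) - log r with r = pi_old / pi_new per token,
    # which is always non-negative, unlike the naive log-ratio.
    log_ratio = old_logprobs - new_logprobs
    return torch.exp(log_ratio) - 1.0 - log_ratio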

The redundant softmax passes when extra metrics are enabled are tracked separately in #507 (consolidation of
head losses/metrics into a single fused kernel) — not blocking this PR.

Issues

  1. compute_chunked_entropy reimplements fused_softmax_base, and entropy_chunk_size duplicates num_splits

Two problems in pg_metrics.py:104-156:

  • The body of compute_chunked_entropy manually does logits-scaling, max, all-reduce-MAX, exp, sum,
    all-reduce-SUM — i.e., reproduces fused_softmax_base inline. The "we handle it manually to avoid a full-tensor
    pass" rationale doesn't hold; fused_softmax_base(logits_chunk, ...) already operates on whatever you hand it.
    After replacing the manual block, the body collapses to:
    logits_norm, exp_logits, sum_exp, _ = fused_softmax_base(logits_chunk, logits_scale_factor, group)
    entropy_chunk = sum_exp.log() - (exp_logits * logits_norm).sum(-1) / sum_exp
  • This also fixes the missing @torch.compile (fused_softmax_base already has it).
  • entropy_chunk_size duplicates _num_splits. The loss layer already chunks the batch via _prepare_target
    (loss.py:111-112) and is invoked once per split by the runner. An inner per-chunk loop inside the entropy
    function is a second, redundant memory-control knob. Remove entropy_chunk_size (config field + parameter);
    compute entropy in one pass on whatever batch slice the runner hands in. If memory is tight, raising num_splits
    is the existing answer.

This obsoletes the entropy_chunk_size valid= validation point from the previous review.

  2. MIN/MAX reductions are unsafe under pipeline parallelism — assert it out

max_advantage / min_advantage are the first uses of ReductionType.maximum/minimum in the codebase. In
LossDef.reduce (engine/base_model/config.py:173-180), pipeline ranks with empty losses lists fall through to a
torch.zeros([1]) placeholder before the pipeline-group reduce_op. With ReduceOp.MAX/MIN, that zero contaminates
the result whenever the real extremum has the opposite sign:

  • all advantages negative → max_advantage reports 0
  • all advantages positive → min_advantage reports 0

Until the placeholder is fixed properly (-inf/+inf per reduction), raise an explicit error in
LanguageModelGRPOLossConfig._validate (or in LanguageModelGRPOLoss.__init__) when compute_extra_metrics is on
and distributed_config.pipeline_parallel > 1. Better than silently logging zeros.
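
A toy demonstration of the contamination (plain torch standing in for the pipeline-group MAX reduce):

import torch

# A rank holding the loss layer reports the true extremum; an empty pipeline
# rank contributes the torch.zeros([1]) placeholder before the MAX reduce.
real_max_advantage = torch.tensor([-0.3])  # all advantages negative
placeholder = torch.zeros([1])
print(torch.maximum(real_max_advantage, placeholder))  # tensor([0.]) -- wrong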

  3. Naming: clamp_log_ratio_new_old_indicator

Two issues — verbose, and "log_ratio" is wrong (the threshold 1 - eps_lo < ratio < 1 + eps_hi is on the ratio
itself, not its log). The metric is the fraction of clipped tokens, so suggest clipped_ratio_fraction (full
words, no abbreviations, accurate). Update both the metrics dataclass field and the LossDef name.
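
Concretely, the quantity being renamed is the fraction of in-mask tokens whose ratio left the clipping band; a sketch (the epsilon names are placeholders for the loss config's bounds, and loss_mask is assumed boolean):

import torch

def clipped_ratio_fraction(
    ratio: torch.Tensor, loss_mask: torch.Tensor, epsilon_low: float, epsilon_high: float
) -> torch.Tensor:
    # Fraction of unmasked tokens where the ratio itself (not its log) fell
    # outside (1 - epsilon_low, 1 + epsilon_high), i.e. where clamping bit.
    clipped = (ratio < 1 - epsilon_low) | (ratio > 1 + epsilon_high)
    return (clipped & loss_mask).sum() / loss_mask.sum()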

  4. Replace the two-bool config with an enum

compute_entropy_metric=True, compute_extra_metrics=False still computes and registers all base metrics — only
the desc text documents the "implies" relationship, which makes it easy to misconfigure. Replace with a single
field, e.g.:

class ExtraMetrics(enum.StrEnum):
    none = "none"
    basic = "basic"                # ratio / KL / advantage stats / clamp fraction / token count
    with_entropy = "with_entropy"  # basic + per-token entropy

extra_metrics: ExtraMetrics = Field(default=ExtraMetrics.none, ...)

Then _forward_backward and get_loss_definitions branch on a single value.
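
A minimal sketch of that single-value branching (the helper methods here are hypothetical, purely to show the shape):

# Hypothetical usage inside get_loss_definitions / _forward_backward:
if self._config.extra_metrics != ExtraMetrics.none:
    defs.extend(self._basic_metric_definitions())        # hypothetical helper
    if self._config.extra_metrics == ExtraMetrics.with_entropy:
        defs.append(self._entropy_metric_definition())   # hypothetical helper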

  5. Tests: rewrite as an extension of tests/layers/test_lm_losses.py

The new tests/layers/test_grpo_metrics.py doesn't follow Fast-LLM testing conventions. The existing
test_lm_losses.py already contains the GRPO test scaffolding to extend:

  • _get_grpo_loss_inputs — standard input generator (correlated old/new logprobs, optional masking, dtype-aware).
  • _compare_losses_and_grads — RMS-relative comparison with project-standard tolerances via
    Assert.rms_close_relative.
  • Parametrization style (loss masking × dtype × shape) matches the rest of the suite.

The new tests should:

  • Live inside test_lm_losses.py (or a new module that imports the same helpers), not introduce a parallel
    mini-framework.
  • Reuse _get_grpo_loss_inputs instead of inlining torch.randn / torch.randint.
  • Use Assert.rms_close_relative instead of a hand-rolled _assert_close.
  • Drop the module-level device = "cuda" if ... (the helpers handle device selection).
  • Drop the # --- ASCII section dividers.
  • The test_mock_vocab_parallel_entropy case validates the math derivation but not the actual
    compute_chunked_entropy(group=...) code path — fold it into the existing DistributedTestContext-style coverage
    in test_lm_losses.py if possible, or remove if it's just shadowing the math.

Once #1 lands (entropy uses fused_softmax_base, no entropy_chunk_size), most of the chunk-size variants in
test_entropy_matches_manual go away too.

  6. PolicyGradientMetrics / pg_metrics.py — speculative abstraction

The class name and the "GRPO, PPO, …" framing in the docstring design for a hypothetical second consumer. Per
CLAUDE.md ("Don't design for hypothetical future requirements"), rename to GRPOMetrics and fold the file into
grpo.py (it's ~200 lines and only used from there). The PPO extraction can happen when a PPO loss arrives, with
full knowledge of what's actually shared. This also tees up the eventual consolidation in #507 nicely.

Summary

Math and basic structure are good. Concrete asks before merge:

  1. Drop entropy_chunk_size; rewrite compute_chunked_entropy on top of fused_softmax_base.
  2. Raise on compute_extra_metrics + PP > 1.
  3. Rename clamp_log_ratio_new_old_indicator → clipped_ratio_fraction (or similar full-word).
  4. Collapse the two bools into a single ExtraMetrics enum.
  5. Migrate tests into test_lm_losses.py, reusing _get_grpo_loss_inputs / Assert.rms_close_relative.
  6. Rename to GRPOMetrics and inline into grpo.py.

Issue #507 tracks the longer-term consolidation that would also remove the duplicate softmax pass.

jlamypoirier and others added 5 commits May 5, 2026 12:34
- Inline pg_metrics.py into grpo.py; rename to GRPOMetrics
- Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy
- Replace two bool flags with a single metrics: GRPOMetricsLevel enum
- Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction
- Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce
  would be corrupted by the zero placeholder on empty pipeline ranks)
- Migrate tests into tests/layers/test_lm_losses.py, reusing the
  existing helpers and parametrization (single + distributed runner)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop stale "second softmax pass" overhead note from `metrics`
  description (entropy now reuses the base softmax outputs)
- De-mirror max/min in reference_grpo_metrics: use
  advantages[loss_mask].max()/.min() instead of the implementation's
  -inf/+inf sentinel pattern

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Align (logits, target, advantages, old_log_probabilities, ...) order
  across compute_grpo_metrics, fused_grpo_loss_forward_backward, and
  reference_grpo_metrics
- Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit
  keyword-only signature mirroring LanguageModelLoss.__init__
- num_docs -> num_documents
- Drop the comment that restated the k3 KL formula
- Give compute_grpo_metrics the same defaults as the loss kernel
- Trim the metrics field description to category-level wording
- Always exercise varying label_counts in _test_grpo_metrics so per-token
  denominator broadcasting is covered
- reference_grpo_metrics returns GRPOMetrics; comparison loop iterates
  dataclasses.fields
- Drop name = self._name micro-rebinds; use self._name inline
- defs = super()...; defs.append(...); defs.extend(...) consistently
- Tighten _register_extra_metrics losses type to dict[str, list[Tensor]]
- Split compiled tuple-returning core from outer GRPOMetrics wrapper to
  avoid @torch.compile graph-breaks on dataclass construction
- One-line comment on the metrics gate explaining the softmax-skip

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NamedTuple is a tuple subclass that dynamo handles natively, so the
previous wrapper/inner split (added to dodge a dataclass graph-break)
collapses into one @torch.compile function. Field order now lives
exactly once — on the class.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
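
A compressed sketch of that pattern (field set reduced for brevity; names follow the metrics listed in the PR summary, not the PR's exact class):

import typing
import torch

class GRPOMetrics(typing.NamedTuple):
    # Field order is declared exactly once, on the class.
    ratio_new_old: torch.Tensor
    kl_new_old: torch.Tensor
    num_tokens: torch.Tensor

@torch.compile
def compute_metrics_sketch(new_logprobs, old_logprobs, loss_mask):
    # dynamo treats NamedTuple as a tuple subclass, so constructing and
    # returning it here does not break the compiled graph (a dataclass did).
    log_ratio = new_logprobs - old_logprobs
    ratio = torch.exp(log_ratio)
    kl = torch.exp(-log_ratio) - 1.0 + log_ratio  # Schulman k3 for KL(new || old)
    n = loss_mask.sum()
    return GRPOMetrics((ratio * loss_mask).sum() / n, (kl * loss_mask).sum() / n, n)
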
- Entropy under vocab-parallel TP was wrong: the dot-product term
  (exp_logits * logits_norm).sum(-1) summed only the local vocab slice,
  so dividing by the global sum_exp_logits gave a per-rank fragment
  instead of the full E_p[logit_norm]. All-reduce the partial sum.
- Replace the verbose pipeline-parallel guard with Assert.custom; the
  field description already explains the constraint.
- Drop the cryptic `# k3` comment.
- Match _register_extra_metrics losses annotation to the base class
  (dict | None).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jlamypoirier jlamypoirier changed the title from "grpo: add policy-gradient metrics behind compute_extra_metrics flag" to "grpo: add policy-gradient metrics behind metrics enum" on May 5, 2026