grpo: add policy-gradient metrics behind metrics enum #494
Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum, ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage, num_tokens, and optional per-token entropy.

New files:
- fast_llm/layers/language_model/loss/pg_metrics.py: reusable PolicyGradientMetrics dataclass + compute_policy_gradient_metrics() (callable by a future PPO loss), with chunked vocab-parallel entropy support.
- tests/layers/test_grpo_metrics.py: 8 unit tests covering single-sequence, packed multi-sequence, masked tokens, clamp fraction, entropy correctness, mock SDP correctness, mock vocab-parallel entropy, and normalization parity.

Config additions to LanguageModelGRPOLossConfig:
- compute_extra_metrics (default False): log all non-entropy metrics.
- compute_entropy_metric (default False): additionally log per-token entropy.
- entropy_chunk_size (default 4096): batch chunk size for the entropy pass.

Normalization matches the existing new_logprobs_mean: sum(v * mask / label_counts), then divided by num_documents_in_batch. MAX/MIN metrics use the LossDef ReductionType and the matching ReduceOp so they aggregate correctly across microbatches and SDP/sequence-parallel ranks.
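For reference, a minimal sketch of that normalization in plain PyTorch (tensor names here are illustrative, not the actual kernel signatures):

```python
import torch


def masked_document_mean(
    values: torch.Tensor,        # per-token metric values, shape (tokens,)
    loss_mask: torch.Tensor,     # 1 for tokens that count, 0 otherwise
    label_counts: torch.Tensor,  # per-token label count of the owning document
    num_documents_in_batch: int,
) -> torch.Tensor:
    # sum(v * mask / label_counts), then divide by the number of documents,
    # matching the existing new_logprobs_mean normalization.
    return (values * loss_mask / label_counts).sum() / num_documents_in_batch
```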
Rename four metrics to match DeepSpeed's naming exactly, so runs on both backends produce comparable WandB keys:
- ratio → ratio_new_old
- ratio_sum → ratio_new_old_sum
- ratio_sq_sum → ratio_new_old_squared_sum
- clamp_frac → clamp_log_ratio_new_old_indicator
Review with Claude. I'll implement the changes myself
PR #494 Review (v2) — grpo: add policy-gradient metrics behind compute_extra_metrics flag
Overview
Adds 10 opt-in GRPO training metrics (ratio stats, Schulman KL, advantage min/max/mean, clamp fraction, token
count) plus optional per-token entropy, gated by config flags on LanguageModelGRPOLossConfig. New pg_metrics.py
module + 8 unit tests on CPU. The math (Schulman k3 KL, masked-mean normalization, vocab-parallel softmax) is
sound.
The redundant softmax passes when extra metrics are enabled are tracked separately in #507 (consolidation of
head losses/metrics into a single fused kernel) — not blocking this PR.
Issues
- compute_chunked_entropy reimplements fused_softmax_base, and entropy_chunk_size duplicates num_splits
Two problems in pg_metrics.py:104-156:
- The body of compute_chunked_entropy manually does logits-scaling, max, all-reduce-MAX, exp, sum,
all-reduce-SUM — i.e., reproduces fused_softmax_base inline. The "we handle it manually to avoid a full-tensor
pass" rationale doesn't hold; fused_softmax_base(logits_chunk, ...) already operates on whatever you hand it.
After replacing the manual block, the body collapses to:

    logits_norm, exp_logits, sum_exp, _ = fused_softmax_base(logits_chunk, logits_scale_factor, group)
    entropy_chunk = sum_exp.log() - (exp_logits * logits_norm).sum(-1) / sum_exp

This also fixes the missing @torch.compile (fused_softmax_base already has it). A standalone check of the underlying entropy identity appears after this issue.
- entropy_chunk_size duplicates _num_splits. The loss layer already chunks the batch via _prepare_target
(loss.py:111-112) and is invoked once per split by the runner. An inner per-chunk loop inside the entropy
function is a second, redundant memory-control knob. Remove entropy_chunk_size (config field + parameter);
compute entropy in one pass on whatever batch slice the runner hands in. If memory is tight, raising num_splits
is the existing answer.
This obsoletes the entropy_chunk_size valid= validation point from the previous review.
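For what it's worth, the identity behind that two-liner can be verified in isolation. A minimal single-device sketch in plain PyTorch (not the Fast-LLM kernels; it only mirrors the quantities fused_softmax_base is described as returning):

```python
import torch


def entropy_from_softmax_parts(logits: torch.Tensor) -> torch.Tensor:
    # The three pieces a fused softmax pass typically exposes:
    logits_norm = logits - logits.max(dim=-1, keepdim=True).values  # max-shifted logits
    exp_logits = logits_norm.exp()
    sum_exp = exp_logits.sum(dim=-1)                                # normalizer Z
    # With p = exp_logits / Z and log p = logits_norm - log Z:
    #   H = -sum(p * log p) = log Z - sum(p * logits_norm)
    return sum_exp.log() - (exp_logits * logits_norm).sum(dim=-1) / sum_exp


logits = torch.randn(4, 7, 32)
p = torch.softmax(logits, dim=-1)
direct = -(p * p.log()).sum(dim=-1)
assert torch.allclose(entropy_from_softmax_parts(logits), direct, atol=1e-5)
```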
- MIN/MAX reductions are unsafe under pipeline parallelism — assert it out
max_advantage / min_advantage are the first uses of ReductionType.maximum/minimum in the codebase. In
LossDef.reduce (engine/base_model/config.py:173-180), pipeline ranks with empty losses lists fall through to a
torch.zeros([1]) placeholder before the pipeline-group reduce_op. With ReduceOp.MAX/MIN, that zero contaminates
the result whenever the real extremum has the opposite sign:
- all advantages negative → max_advantage reports 0
- all advantages positive → min_advantage reports 0
Until the placeholder is fixed properly (-inf/+inf per reduction), raise an explicit error in LanguageModelGRPOLossConfig._validate (or in LanguageModelGRPOLoss.__init__) when compute_extra_metrics is on and distributed_config.pipeline_parallel > 1. Better than silently logging zeros.
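A possible shape for that guard, reduced to a standalone helper (names and error wording are placeholders, not the final API):

```python
def check_grpo_metrics_supported(compute_extra_metrics: bool, pipeline_parallel: int) -> None:
    # MIN/MAX loss reductions would be contaminated by the torch.zeros([1])
    # placeholder on pipeline ranks that hold no losses, so refuse the
    # combination outright instead of silently logging wrong extrema.
    if compute_extra_metrics and pipeline_parallel > 1:
        raise ValueError(
            "GRPO extra metrics are not supported with pipeline_parallel > 1: "
            "max/min advantage reductions would be corrupted by the zero placeholder "
            "on empty pipeline ranks."
        )
```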
- Naming: clamp_log_ratio_new_old_indicator
Two issues — verbose, and "log_ratio" is wrong (the threshold 1 - eps_lo < ratio < 1 + eps_hi is on the ratio
itself, not its log). The metric is the fraction of clipped tokens, so suggest clipped_ratio_fraction (full
words, no abbreviations, accurate). Update both the metrics dataclass field and the LossDef name.
- Replace the two-bool config with an enum
compute_entropy_metric=True, compute_extra_metrics=False still computes and registers all base metrics — only
the desc text documents the "implies" relationship, which makes it easy to misconfigure. Replace with a single
field, e.g.:
    class ExtraMetrics(enum.StrEnum):
        none = "none"
        basic = "basic"                # ratio / KL / advantage stats / clamp fraction / token count
        with_entropy = "with_entropy"  # basic + per-token entropy

    extra_metrics: ExtraMetrics = Field(default=ExtraMetrics.none, ...)
Then _forward_backward and get_loss_definitions branch on a single value.
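A runnable illustration of that shape (the enum is taken from the snippet above; the two predicates are placeholders standing in for the real branches in _forward_backward / get_loss_definitions):

```python
import enum


class ExtraMetrics(enum.StrEnum):
    none = "none"
    basic = "basic"                # ratio / KL / advantage stats / clamp fraction / token count
    with_entropy = "with_entropy"  # basic + per-token entropy


def wants_base_metrics(level: ExtraMetrics) -> bool:
    return level is not ExtraMetrics.none


def wants_entropy(level: ExtraMetrics) -> bool:
    return level is ExtraMetrics.with_entropy


# Entropy can no longer be requested without the base metrics.
assert wants_base_metrics(ExtraMetrics.with_entropy) and wants_entropy(ExtraMetrics.with_entropy)
assert wants_base_metrics(ExtraMetrics.basic) and not wants_entropy(ExtraMetrics.basic)
assert not wants_base_metrics(ExtraMetrics.none) and not wants_entropy(ExtraMetrics.none)
```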
- Tests: rewrite as an extension of tests/layers/test_lm_losses.py
The new tests/layers/test_grpo_metrics.py doesn't follow Fast-LLM testing conventions. The existing
test_lm_losses.py already contains the GRPO test scaffolding to extend:
- _get_grpo_loss_inputs — standard input generator (correlated old/new logprobs, optional masking, dtype-aware).
- _compare_losses_and_grads — RMS-relative comparison with project-standard tolerances via Assert.rms_close_relative.
- Parametrization style (loss masking × dtype × shape) matches the rest of the suite.
The new tests should:
- Live inside test_lm_losses.py (or a new module that imports the same helpers), not introduce a parallel mini-framework.
- Reuse _get_grpo_loss_inputs instead of inlining torch.randn / torch.randint.
- Use Assert.rms_close_relative instead of a hand-rolled _assert_close.
- Drop the module-level device = "cuda" if ... (the helpers handle device selection).
- Drop the # --- ASCII section dividers.
- The test_mock_vocab_parallel_entropy case validates the math derivation but not the actual
compute_chunked_entropy(group=...) code path — fold it into the existing DistributedTestContext-style coverage
in test_lm_losses.py if possible, or remove if it's just shadowing the math.
Once #1 lands (entropy uses fused_softmax_base, no entropy_chunk_size), most of the chunk-size variants in
test_entropy_matches_manual go away too.
- PolicyGradientMetrics / pg_metrics.py — speculative abstraction
The class name and the "GRPO, PPO, …" framing in the docstring design for a hypothetical second consumer. Per
CLAUDE.md ("Don't design for hypothetical future requirements"), rename to GRPOMetrics and fold the file into
grpo.py (it's ~200 lines and only used from there). The PPO extraction can happen when a PPO loss arrives, with
full knowledge of what's actually shared. This also tees up the eventual consolidation in #507 nicely.
Summary
Math and basic structure are good. Concrete asks before merge:
- Drop entropy_chunk_size; rewrite compute_chunked_entropy on top of fused_softmax_base.
- Raise on compute_extra_metrics + PP > 1.
- Rename clamp_log_ratio_new_old_indicator → clipped_ratio_fraction (or similar full-word).
- Collapse the two bools into a single ExtraMetrics enum.
- Migrate tests into test_lm_losses.py, reusing _get_grpo_loss_inputs / Assert.rms_close_relative.
- Rename to GRPOMetrics and inline into grpo.py.
Issue #507 tracks the longer-term consolidation that would also remove the duplicate softmax pass.
- Inline pg_metrics.py into grpo.py; rename to GRPOMetrics
- Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy
- Replace two bool flags with a single metrics: GRPOMetricsLevel enum
- Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction
- Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce would be corrupted by the zero placeholder on empty pipeline ranks)
- Migrate tests into tests/layers/test_lm_losses.py, reusing the existing helpers and parametrization (single + distributed runner)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Drop stale "second softmax pass" overhead note from `metrics` description (entropy now reuses the base softmax outputs)
- De-mirror max/min in reference_grpo_metrics: use advantages[loss_mask].max()/.min() instead of the implementation's -inf/+inf sentinel pattern

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Align (logits, target, advantages, old_log_probabilities, ...) order across compute_grpo_metrics, fused_grpo_loss_forward_backward, and reference_grpo_metrics
- Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit keyword-only signature mirroring LanguageModelLoss.__init__
- num_docs -> num_documents
- Drop the comment that restated the k3 KL formula
- Give compute_grpo_metrics the same defaults as the loss kernel
- Trim the metrics field description to category-level wording
- Always exercise varying label_counts in _test_grpo_metrics so per-token denominator broadcasting is covered
- reference_grpo_metrics returns GRPOMetrics; comparison loop iterates dataclasses.fields
- Drop name = self._name micro-rebinds; use self._name inline
- defs = super()...; defs.append(...); defs.extend(...) consistently
- Tighten _register_extra_metrics losses type to dict[str, list[Tensor]]
- Split compiled tuple-returning core from outer GRPOMetrics wrapper to avoid @torch.compile graph-breaks on dataclass construction
- One-line comment on the metrics gate explaining the softmax-skip

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
NamedTuple is a tuple subclass that dynamo handles natively, so the previous wrapper/inner split (added to dodge a dataclass graph-break) collapses into one @torch.compile function. Field order now lives exactly once — on the class.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- Entropy under vocab-parallel TP was wrong: the dot-product term (exp_logits * logits_norm).sum(-1) summed only the local vocab slice, so dividing by the global sum_exp_logits gave a per-rank fragment instead of the full E_p[logit_norm]. All-reduce the partial sum.
- Replace the verbose pipeline-parallel guard with Assert.custom; the field description already explains the constraint.
- Drop the cryptic `# k3` comment.
- Match _register_extra_metrics losses annotation to the base class (dict | None).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
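A minimal sketch of the corrected pattern in plain torch.distributed (illustrative only; the real code flows through fused_softmax_base and Fast-LLM's TP group plumbing, and sum_exp_logits is assumed to be already reduced over the full vocabulary):

```python
import torch
import torch.distributed as dist


def vocab_parallel_entropy(
    logits_norm: torch.Tensor,     # local vocab slice, scaled and max-shifted
    exp_logits: torch.Tensor,      # exp of the local slice
    sum_exp_logits: torch.Tensor,  # global normalizer (already all-reduced)
    group: dist.ProcessGroup | None,
) -> torch.Tensor:
    # The dot-product term must cover the *full* vocabulary; summing only the
    # local slice and dividing by the global normalizer yields a per-rank
    # fragment of E_p[logit_norm] instead of the whole expectation.
    partial = (exp_logits * logits_norm).sum(dim=-1)
    if group is not None:  # skip the collective when there is no TP group
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)
    return sum_exp_logits.log() - partial / sum_exp_logits
```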
Summary
- New metrics are configured on LanguageModelGRPOLossConfig, opt-in via a single metrics enum (none | basic | with_entropy). basic logs old_logprobs, ratio_new_old, ratio_new_old_sum, ratio_new_old_squared_sum, kl_new_old, clipped_ratio_fraction, advantage, max_advantage, min_advantage, num_tokens. with_entropy additionally logs per-token entropy.
- Everything lives in fast_llm/layers/language_model/loss/grpo.py. compute_grpo_metrics is a single @torch.compile function returning a GRPOMetrics NamedTuple; entropy reuses fused_softmax_base's already-computed exp_logits / logits_norm / sum_exp_logits (no second softmax pass).
- Under vocab-parallel TP, (exp_logits * logits_norm).sum(-1) is all-reduced over the TP group before dividing by the global sum_exp_logits, matching the pattern in _fused_cross_entropy_base_from_distribution.
- MAX/MIN reductions would be corrupted by the LossDef.reduce zero placeholder, so LanguageModelGRPOLoss.__init__ rejects metrics != none with pipeline_parallel > 1.

Omitted metrics
The following DeepSpeed-side metrics cannot be added because the required inputs are not available in Fast-LLM's GRPO training path:
- ratio_ref / kl(new||ref) — GRPO has no reference model; these only exist in PPO-with-KL-penalty variants.
- kl_coef, entropy_bonus_coef — scalar config constants, not per-token metrics.

Enabling
Test plan
- pytest tests/layers/test_lm_losses.py — all test_grpo_metrics (single-rank) and test_lm_loss_distributed[grpo_metrics-*] (TP=2 over gloo) cases pass on CPU.
- Logged values land at the expected points: ratio_new_old / num_tokens ≈ 1, kl_new_old ≈ 0, clipped_ratio_fraction ≈ 0.
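Those expected values follow directly from the definitions whenever the new and old per-token log-probabilities coincide; a quick standalone check (not the project's test code, and the ratio direction inside the kernel is an assumption):

```python
import torch

new_logprobs = torch.randn(16)
old_logprobs = new_logprobs.clone()          # identical old/new policies

ratio = (new_logprobs - old_logprobs).exp()  # ratio_new_old == 1 everywhere
# The Schulman k3 estimator r - log(r) - 1 is non-negative and exactly 0 at r == 1,
# so kl_new_old == 0 here regardless of which ratio direction the kernel uses.
k3 = ratio - ratio.log() - 1.0

assert torch.allclose(ratio, torch.ones_like(ratio))
assert torch.allclose(k3, torch.zeros_like(k3))
# No token is clipped: 1 - eps < ratio == 1 < 1 + eps, hence clipped_ratio_fraction == 0.
```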