From 98836e6dcea6e60c4e1112ca6bf51350df593784 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 27 Apr 2026 12:56:21 +0000 Subject: [PATCH 1/7] grpo: add policy-gradient metrics behind compute_extra_metrics flag Adds GRPO metrics parity with DeepSpeed: old_logprobs, ratio, ratio_sum, ratio_sq_sum, kl_new_old, clamp_frac, advantage, max/min_advantage, num_tokens, and optional per-token entropy. New files: - fast_llm/layers/language_model/loss/pg_metrics.py: reusable PolicyGradientMetrics dataclass + compute_policy_gradient_metrics() (callable by future PPO), with chunked vocab-parallel entropy support. - tests/layers/test_grpo_metrics.py: 8 unit tests covering single-seq, packed multi-seq, masked tokens, clamp fraction, entropy correctness, mock SDP correctness, mock vocab-parallel entropy, normalization parity. Config additions to LanguageModelGRPOLossConfig: - compute_extra_metrics (default False): log all non-entropy metrics - compute_entropy_metric (default False): additionally log per-token entropy - entropy_chunk_size (default 4096): batch chunk size for entropy pass Normalization matches existing new_logprobs_mean: sum(v*mask/label_counts) then divided by num_documents_in_batch. MAX/MIN use LossDef ReductionType and correct ReduceOp so they aggregate correctly across microbatches and SDP/sequence-parallel ranks. --- fast_llm/layers/language_model/loss/config.py | 15 + fast_llm/layers/language_model/loss/grpo.py | 86 +++- .../layers/language_model/loss/pg_metrics.py | 210 ++++++++++ tests/layers/test_grpo_metrics.py | 377 ++++++++++++++++++ 4 files changed, 686 insertions(+), 2 deletions(-) create mode 100644 fast_llm/layers/language_model/loss/pg_metrics.py create mode 100644 tests/layers/test_grpo_metrics.py diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 4381aa5d9..4f91724a2 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -205,6 +205,21 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc="Enable triton implementation. Default: use if available.", hint=FieldHint.expert, ) + compute_extra_metrics: bool = Field( + default=False, + desc="Log additional GRPO metrics: old_logprobs, ratio, KL(new||old), advantage stats, clamp fraction, token count.", + hint=FieldHint.feature, + ) + compute_entropy_metric: bool = Field( + default=False, + desc="Also log per-token entropy (-Σ p log p). Requires a second pass over logits (~10-20%% overhead). Implies compute_extra_metrics.", + hint=FieldHint.feature, + ) + entropy_chunk_size: int = Field( + default=4096, + desc="Batch chunk size for chunked entropy computation. 
Memory per chunk ∝ chunk_size × vocab_local.", + hint=FieldHint.expert, + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index cc6cbf726..4cb66522c 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -3,7 +3,7 @@ import torch -from fast_llm.engine.base_model.config import LossDef +from fast_llm.engine.base_model.config import LossDef, ReductionType from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base from fast_llm.functional.utils import reduce_losses @@ -51,10 +51,92 @@ def _forward_backward( self._register_loss( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) + + if losses is not None and (self._config.compute_extra_metrics or self._config.compute_entropy_metric): + self._register_pg_metrics(logits, kwargs, losses, split_index) + return loss, grad + def _register_pg_metrics( + self, + logits: torch.Tensor, + kwargs: dict[str, typing.Any], + losses: dict, + split_index: int, + ) -> None: + from fast_llm.layers.language_model.loss.pg_metrics import compute_policy_gradient_metrics + + metrics = compute_policy_gradient_metrics( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index), + self._config.epsilon_low, + self._config.epsilon_high, + self._logits_scale_factor, + vocab_parallel_group=self._parallel_dim.group if self._vocab_parallel else None, + compute_entropy=self._config.compute_entropy_metric, + entropy_chunk_size=self._config.entropy_chunk_size, + ) + + num_docs = kwargs[LanguageModelKwargs.num_documents_in_batch] + name = self._name + + # Per-token mean metrics: divide by num_docs to match new_logprobs_mean normalization. + for attr, suffix in ( + ("old_logprobs", "old_logprobs"), + ("ratio", "ratio"), + ("kl_new_old", "kl_new_old"), + ("clamp_frac", "clamp_frac"), + ("advantage", "advantage"), + ): + self._register_loss(f"{name}_{suffix}", getattr(metrics, attr) / num_docs, losses) + + # Raw sum metrics (no per-doc normalization). + for attr, suffix in ( + ("ratio_sum", "ratio_sum"), + ("ratio_sq_sum", "ratio_sq_sum"), + ("num_tokens", "num_tokens"), + ): + self._register_loss(f"{name}_{suffix}", getattr(metrics, attr), losses) + + # MAX/MIN metrics: pass correct reduce_op for sequence-parallel mode. 
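+        # (A SUM reduce would add per-rank extrema together rather than compare
+        # them, hence the explicit MAX/MIN ops below.)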
+ self._register_loss( + f"{name}_max_advantage", + metrics.max_advantage, + losses, + reduce_op=torch.distributed.ReduceOp.MAX, + ) + self._register_loss( + f"{name}_min_advantage", + metrics.min_advantage, + losses, + reduce_op=torch.distributed.ReduceOp.MIN, + ) + + if metrics.entropy is not None: + self._register_loss(f"{name}_entropy", metrics.entropy / num_docs, losses) + def get_loss_definitions(self) -> list[LossDef]: - return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] + defs = super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] + if self._config.compute_extra_metrics or self._config.compute_entropy_metric: + name = self._name + defs += [ + LossDef(f"{name}_old_logprobs"), + LossDef(f"{name}_ratio"), + LossDef(f"{name}_ratio_sum"), + LossDef(f"{name}_ratio_sq_sum"), + LossDef(f"{name}_kl_new_old"), + LossDef(f"{name}_clamp_frac"), + LossDef(f"{name}_advantage"), + LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), + LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), + LossDef(f"{name}_num_tokens"), + ] + if self._config.compute_entropy_metric: + defs.append(LossDef(f"{name}_entropy")) + return defs def get_preprocessing_config( self, diff --git a/fast_llm/layers/language_model/loss/pg_metrics.py b/fast_llm/layers/language_model/loss/pg_metrics.py new file mode 100644 index 000000000..72c8c811a --- /dev/null +++ b/fast_llm/layers/language_model/loss/pg_metrics.py @@ -0,0 +1,210 @@ +import dataclasses + +import torch +import torch.distributed + +from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base + + +@dataclasses.dataclass +class PolicyGradientMetrics: + """ + Scalar metrics for policy-gradient losses (GRPO, PPO, …). + + All per-token-mean fields use the same normalization as new_logprobs_mean: + sum(value * mask / label_counts.clamp(1)) + The caller must then divide by num_documents_in_batch for the final logged value. + + ratio_sum / ratio_sq_sum are raw masked sums (no label_counts division) for ESS. + + max_advantage / min_advantage are raw per-local-batch extrema; the caller must + all_reduce them with ReduceOp.MAX / ReduceOp.MIN across SDP ranks. 
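+
+    Worked example of the normalization (illustrative): two packed sequences of
+    lengths 2 and 3 give label_counts = [2, 2, 3, 3, 3]; a per-token-mean field
+    is then sum(v_i / count_i) = (v_0 + v_1) / 2 + (v_2 + v_3 + v_4) / 3, i.e. a
+    sum of per-sequence means, and dividing by num_documents_in_batch (= 2)
+    yields the mean over sequences.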
+ """ + + old_logprobs: torch.Tensor # per-token mean (label_counts normalised) + ratio: torch.Tensor # per-token mean IS ratio + ratio_sum: torch.Tensor # raw masked sum (ESS numerator) + ratio_sq_sum: torch.Tensor # raw masked sum (ESS denominator) + kl_new_old: torch.Tensor # per-token mean Schulman KL approx + clamp_frac: torch.Tensor # per-token mean clipping indicator + advantage: torch.Tensor # per-token mean + max_advantage: torch.Tensor # max over masked tokens (caller does MAX all-reduce) + min_advantage: torch.Tensor # min over masked tokens (caller does MIN all-reduce) + num_tokens: torch.Tensor # raw masked sum + entropy: torch.Tensor | None # per-token mean entropy; None when not requested + + +@torch.compile +def _compute_pg_base_metrics( + logits: torch.Tensor, # (*batch, vocab_local) + target: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) + advantages: torch.Tensor, # (*batch,) + label_counts: torch.Tensor, # (*batch,) global per-seq count, broadcast per token + epsilon_low: float, + epsilon_high: float, + logits_scale_factor: float, + group: torch.distributed.ProcessGroup | None, +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, +]: + """Compute all non-entropy policy-gradient metrics in a single fused pass.""" + loss_mask = target >= 0 + mask = loss_mask.float() + denom = label_counts.float().clamp(min=1) + + logits_norm, _, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) + new_log_probs = predicted_logits - sum_exp_logits.log() + + log_ratio = new_log_probs - old_log_probabilities + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) + + # Schulman KL approximation: exp(r) - r - 1 + kl = ratio - log_ratio - 1.0 + + old_lp = (old_log_probabilities * mask / denom).sum() + ratio_mean = (ratio * mask / denom).sum() + ratio_sum = (ratio * mask).sum() + ratio_sq_sum = (ratio * ratio * mask).sum() + kl_mean = (kl * mask / denom).sum() + clamp_mean = (clipped.float() * mask / denom).sum() + adv_mean = (advantages * mask / denom).sum() + num_tokens = mask.sum() + + # max/min over masked positions; fill non-masked with sentinel values + neg_inf = advantages.new_full((), float("-inf")) + pos_inf = advantages.new_full((), float("inf")) + max_adv = torch.where(loss_mask, advantages, neg_inf).max() + min_adv = torch.where(loss_mask, advantages, pos_inf).min() + + return old_lp, ratio_mean, ratio_sum, ratio_sq_sum, kl_mean, clamp_mean, adv_mean, max_adv, min_adv, num_tokens + + +def compute_chunked_entropy( + logits: torch.Tensor, # (*batch, vocab_local) + target: torch.Tensor, # (*batch,) — used only for loss_mask + label_counts: torch.Tensor, # (*batch,) + logits_scale_factor: float, + group: torch.distributed.ProcessGroup | None, + chunk_size: int = 4096, +) -> torch.Tensor: + """ + Compute per-token entropy -Σ p log p, chunked over the batch dimension to + limit peak memory. Supports vocab-parallel via all-reduce per chunk. + + Returns a scalar using the same label_counts normalisation as other mean metrics + (sum of per-sequence mean entropies). Caller must divide by num_documents_in_batch. + + Memory per chunk: chunk_size × vocab_local × 4 bytes. + At chunk_size=4096, vocab_local=19K (8-way TP): ~300 MB. 
Entropy formula (numerically stable):
+        entropy_i = log(Σ exp(x_j - x_max)) - Σ(exp(x_j - x_max) * (x_j - x_max)) / Σ exp(x_j - x_max)
+                  = log(sum_exp) - (exp_logits · logits_norm).sum() / sum_exp
+    """
+    loss_mask = target >= 0
+    mask = loss_mask.float()
+    denom = label_counts.float().clamp(min=1)
+
+    batch_size = logits.shape[0]
+    total = logits.new_zeros(())
+
+    for start in range(0, batch_size, chunk_size):
+        sl = slice(start, start + chunk_size)
+        logits_chunk = logits[sl]
+
+        # Recompute softmax base for this chunk only.
+        # Scale here since fused_softmax_base expects the full tensor for max/all-reduce;
+        # we handle it manually to avoid a full-tensor pass.
+        if logits_scale_factor != 1.0:
+            logits_chunk = logits_chunk * logits_scale_factor
+
+        logits_max = logits_chunk.float().max(dim=-1).values
+        if group is not None:
+            torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group)
+
+        logits_norm_chunk = logits_chunk.float() - logits_max.unsqueeze(-1)
+        exp_chunk = logits_norm_chunk.exp()
+        sum_exp_chunk = exp_chunk.sum(dim=-1)
+        if group is not None:
+            torch.distributed.all_reduce(sum_exp_chunk, op=torch.distributed.ReduceOp.SUM, group=group)
+
+        # entropy_i = log(sum_exp) - (exp · logits_norm).sum(-1) / sum_exp
+        # The dot product covers only the local vocab shard, so in vocab-parallel
+        # mode it needs the same SUM all-reduce as sum_exp.
+        dot_chunk = (exp_chunk * logits_norm_chunk).sum(dim=-1)
+        if group is not None:
+            torch.distributed.all_reduce(dot_chunk, op=torch.distributed.ReduceOp.SUM, group=group)
+        entropy_chunk = sum_exp_chunk.log() - dot_chunk / sum_exp_chunk
+
+        total = total + (entropy_chunk * mask[sl] / denom[sl]).sum()
+
+    return total
+
+
+def compute_policy_gradient_metrics(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+    old_log_probabilities: torch.Tensor,
+    advantages: torch.Tensor,
+    label_counts: torch.Tensor,
+    epsilon_low: float,
+    epsilon_high: float,
+    logits_scale_factor: float,
+    vocab_parallel_group: torch.distributed.ProcessGroup | None,
+    compute_entropy: bool = False,
+    entropy_chunk_size: int = 4096,
+) -> PolicyGradientMetrics:
+    (
+        old_lp,
+        ratio_mean,
+        ratio_sum,
+        ratio_sq_sum,
+        kl_mean,
+        clamp_mean,
+        adv_mean,
+        max_adv,
+        min_adv,
+        num_tokens,
+    ) = _compute_pg_base_metrics(
+        logits,
+        target,
+        old_log_probabilities,
+        advantages,
+        label_counts,
+        epsilon_low,
+        epsilon_high,
+        logits_scale_factor,
+        vocab_parallel_group,
+    )
+
+    entropy = None
+    if compute_entropy:
+        entropy = compute_chunked_entropy(
+            logits,
+            target,
+            label_counts,
+            logits_scale_factor,
+            vocab_parallel_group,
+            entropy_chunk_size,
+        )
+
+    return PolicyGradientMetrics(
+        old_logprobs=old_lp,
+        ratio=ratio_mean,
+        ratio_sum=ratio_sum,
+        ratio_sq_sum=ratio_sq_sum,
+        kl_new_old=kl_mean,
+        clamp_frac=clamp_mean,
+        advantage=adv_mean,
+        max_advantage=max_adv,
+        min_advantage=min_adv,
+        num_tokens=num_tokens,
+        entropy=entropy,
+    )
diff --git a/tests/layers/test_grpo_metrics.py b/tests/layers/test_grpo_metrics.py
new file mode 100644
index 000000000..f3ac4c5fe
--- /dev/null
+++ b/tests/layers/test_grpo_metrics.py
@@ -0,0 +1,377 @@
+"""
+Unit tests for pg_metrics.py — PolicyGradientMetrics computation.
+
+All tests run on CPU (or GPU if available) without distributed communication
+(vocab_parallel_group=None). Distributed reduction is exercised conceptually
+via the mock-SDP and mock-vocab-parallel sections.
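+
+Run standalone with e.g. `pytest tests/layers/test_grpo_metrics.py -q`
+(assuming pytest is the test runner).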
+""" + +import math + +import torch + +from fast_llm.layers.language_model.loss.pg_metrics import ( + compute_chunked_entropy, + compute_policy_gradient_metrics, +) + +# --------------------------------------------------------------------------- +# helpers +# --------------------------------------------------------------------------- + +device = "cuda" if torch.cuda.is_available() else "cpu" + + +def _manual_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo, eps_hi): + """Reference implementation (pure PyTorch, no compilation).""" + loss_mask = target >= 0 + mask = loss_mask.float() + denom = label_counts.float().clamp(min=1) + + log_softmax = torch.log_softmax(logits.float(), dim=-1) + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + + log_ratio = new_log_probs - old_log_probs.float() + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - eps_lo) | (ratio > 1.0 + eps_hi) + kl = ratio - log_ratio - 1.0 + + old_lp = (old_log_probs.float() * mask / denom).sum() + ratio_mean = (ratio * mask / denom).sum() + ratio_sum = (ratio * mask).sum() + ratio_sq_sum = (ratio * ratio * mask).sum() + kl_mean = (kl * mask / denom).sum() + clamp_mean = (clipped.float() * mask / denom).sum() + adv_mean = (advantages.float() * mask / denom).sum() + max_adv = advantages.float()[loss_mask].max() + min_adv = advantages.float()[loss_mask].min() + num_tokens = mask.sum() + + probs = log_softmax.exp() + entropy_per_token = -(probs * log_softmax).sum(-1) + entropy_mean = (entropy_per_token * mask / denom).sum() + + return dict( + old_logprobs=old_lp, + ratio=ratio_mean, + ratio_sum=ratio_sum, + ratio_sq_sum=ratio_sq_sum, + kl_new_old=kl_mean, + clamp_frac=clamp_mean, + advantage=adv_mean, + max_advantage=max_adv, + min_advantage=min_adv, + num_tokens=num_tokens, + entropy=entropy_mean, + ) + + +def _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.2, eps_hi=0.2, chunk_size=4096): + return compute_policy_gradient_metrics( + logits, + target, + old_log_probs, + advantages, + label_counts, + eps_lo, + eps_hi, + logits_scale_factor=1.0, + vocab_parallel_group=None, + compute_entropy=True, + entropy_chunk_size=chunk_size, + ) + + +def _assert_close(a, b, msg="", atol=1e-5): + assert abs(a.item() - b.item()) < atol, f"{msg}: got {a.item():.8f}, expected {b.item():.8f}" + + +# --------------------------------------------------------------------------- +# 1. Single sequence — all metrics match manual computation +# --------------------------------------------------------------------------- + + +def test_single_sequence_all_metrics(): + torch.manual_seed(0) + seq_len, vocab = 12, 8 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + old_log_probs = torch.randn(seq_len, device=device) - 3.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) # all tokens in one seq + + ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + + for key in ref: + _assert_close(getattr(got, key), ref[key], msg=key) + + +# --------------------------------------------------------------------------- +# 2. 
Packed multi-sequence — per-sequence normalization
+# ---------------------------------------------------------------------------
+
+
+def test_packed_multi_sequence():
+    """
+    Three sequences of lengths [4, 6, 5] packed into one flat batch (15 tokens).
+    label_counts broadcasts the global per-sequence count.
+    """
+    torch.manual_seed(1)
+    lengths = [4, 6, 5]
+    total = sum(lengths)
+    vocab = 10
+
+    logits = torch.randn(total, vocab, device=device)
+    target = torch.randint(0, vocab, (total,), device=device)
+    old_log_probs = torch.randn(total, device=device) - 2.0
+    advantages = torch.randn(total, device=device)
+    label_counts = torch.tensor([l for l in lengths for _ in range(l)], dtype=torch.long, device=device)
+
+    ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2)
+    got = _run_metrics(logits, target, old_log_probs, advantages, label_counts)
+
+    for key in ref:
+        _assert_close(getattr(got, key), ref[key], msg=key)
+
+
+# ---------------------------------------------------------------------------
+# 3. Masked tokens — masked-out tokens must not contribute
+# ---------------------------------------------------------------------------
+
+
+def test_masked_tokens_do_not_contribute():
+    """
+    A batch where half the tokens are masked (target=-100).
+    Metrics computed on full batch should equal metrics on unmasked subset only.
+    """
+    torch.manual_seed(2)
+    seq_len, vocab = 20, 16
+    logits = torch.randn(seq_len, vocab, device=device)
+    target_full = torch.randint(0, vocab, (seq_len,), device=device)
+
+    # mask the first half
+    mask_bool = torch.ones(seq_len, dtype=torch.bool, device=device)
+    mask_bool[: seq_len // 2] = False
+    target_masked = torch.where(mask_bool, target_full, torch.full_like(target_full, -100))
+
+    old_log_probs = torch.randn(seq_len, device=device) - 2.0
+    advantages = torch.randn(seq_len, device=device)
+    label_counts = torch.full((seq_len,), mask_bool.sum().item(), device=device)
+
+    # reference: only the unmasked slice
+    half = seq_len // 2
+    ref = _manual_metrics(
+        logits[half:],
+        target_full[half:],
+        old_log_probs[half:],
+        advantages[half:],
+        label_counts[half:],
+        0.2,
+        0.2,
+    )
+    got = _run_metrics(logits, target_masked, old_log_probs, advantages, label_counts)
+
+    for key in ref:
+        _assert_close(getattr(got, key), ref[key], msg=f"masked_{key}")
+
+
+# ---------------------------------------------------------------------------
+# 4. Clamp fraction — known ratios → known clamp_frac
+# ---------------------------------------------------------------------------
+
+
+def test_clamp_fraction_known():
+    """
+    Construct logits so that probability_ratio is exactly known.
+    With eps_lo=0.1, eps_hi=0.1 and 5 tokens:
+      3 tokens outside the clip range, 2 inside → clamp_frac = 3/5.
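+    (old_log_probs is back-solved per token as new_log_prob - log(target_ratio),
+    so the realized ratios exp(new - old) are exact.)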
+ """ + seq_len, vocab = 5, 4 + # uniform logits → probabilities = 1/vocab for any label + logits = torch.zeros(seq_len, vocab, device=device) + target = torch.zeros(seq_len, dtype=torch.long, device=device) # all label=0 + # p_new = 1/4, so new_log_prob = log(0.25) + new_lp = math.log(1.0 / vocab) + + # Set old_log_probs so ratio = exp(new - old) is known per token + # ratios: [0.85, 1.0, 1.05, 1.2, 0.75] (eps=0.1 → clip outside (0.9, 1.1)) + # clipped: True, False, False, True, True → 3 clipped + ratios = torch.tensor([0.85, 1.0, 1.05, 1.2, 0.75], device=device) + old_log_probs = torch.full((seq_len,), new_lp, device=device) - ratios.log() + + advantages = torch.ones(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.1, eps_hi=0.1) + + expected_clamp_frac = 3.0 / seq_len # 3 out of 5 tokens clipped + _assert_close(got.clamp_frac, torch.tensor(expected_clamp_frac), msg="clamp_frac", atol=1e-5) + + +# --------------------------------------------------------------------------- +# 5. Entropy correctness — small vocab, verify chunked vs reference +# --------------------------------------------------------------------------- + + +def test_entropy_matches_manual(): + """Small vocab so we can compute entropy exactly by hand.""" + torch.manual_seed(3) + seq_len, vocab = 8, 6 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + old_log_probs = torch.randn(seq_len, device=device) - 2.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + + # Reference entropy + ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) + + # Test with different chunk sizes (including chunk_size=1 and chunk_size>seq_len) + for chunk_size in (1, 3, seq_len, seq_len + 10): + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, chunk_size=chunk_size) + _assert_close(got.entropy, ref["entropy"], msg=f"entropy chunk_size={chunk_size}") + + +# --------------------------------------------------------------------------- +# 6. Mock SDP — split batch in half, verify sum/max/min consistency +# --------------------------------------------------------------------------- + + +def test_mock_sdp_split(): + """ + Simulate two SDP ranks each holding half the batch. + SUM metrics on full batch == sum of the two halves. + MAX/MIN metrics on full batch == max/min of the two halves. 
+ """ + torch.manual_seed(4) + seq_len, vocab = 18, 12 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + old_log_probs = torch.randn(seq_len, device=device) - 2.0 + advantages = torch.randn(seq_len, device=device) + label_counts = torch.full((seq_len,), seq_len // 2, device=device) + + half = seq_len // 2 + + full = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + lo = _run_metrics(logits[:half], target[:half], old_log_probs[:half], advantages[:half], label_counts[:half]) + hi = _run_metrics(logits[half:], target[half:], old_log_probs[half:], advantages[half:], label_counts[half:]) + + # SUM metrics accumulate across both halves + for attr in ( + "old_logprobs", + "ratio", + "ratio_sum", + "ratio_sq_sum", + "kl_new_old", + "clamp_frac", + "advantage", + "num_tokens", + ): + combined = getattr(lo, attr) + getattr(hi, attr) + _assert_close(getattr(full, attr), combined, msg=f"sdp_{attr}") + + # MAX/MIN are extrema across both halves + _assert_close(full.max_advantage, torch.max(lo.max_advantage, hi.max_advantage), msg="sdp_max_adv") + _assert_close(full.min_advantage, torch.min(lo.min_advantage, hi.min_advantage), msg="sdp_min_adv") + + # Entropy (SUM metric) + _assert_close(full.entropy, lo.entropy + hi.entropy, msg="sdp_entropy") + + +# --------------------------------------------------------------------------- +# 7. Mock vocab-parallel entropy — split logits along vocab dim +# --------------------------------------------------------------------------- + + +def test_mock_vocab_parallel_entropy(): + """ + Simulate 2-way vocab-parallel: split logits along the vocab dim. + Each "rank" computes a partial softmax; the global entropy should + match single-rank computation (all-reduce simulated manually). + """ + torch.manual_seed(5) + seq_len, vocab = 10, 16 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + mask = torch.ones(seq_len, dtype=torch.bool, device=device) + + # Reference: single rank, full vocab + ref_entropy = compute_chunked_entropy( + logits, + target, + label_counts, + logits_scale_factor=1.0, + group=None, + chunk_size=seq_len, + ) + + # Simulate vocab-parallel: split vocab into [0:8] and [8:16] + # Both ranks see the same sequence but different vocab shards. 
+ # global max is needed for numerical stability: + logits_max = logits.float().max(dim=-1).values # (seq_len,) + + half_v = vocab // 2 + logits_lo = logits[:, :half_v] + logits_hi = logits[:, half_v:] + + # Per rank: compute local sum_exp relative to global max + exp_lo = (logits_lo.float() - logits_max.unsqueeze(-1)).exp() + exp_hi = (logits_hi.float() - logits_max.unsqueeze(-1)).exp() + sum_exp_lo = exp_lo.sum(-1) + sum_exp_hi = exp_hi.sum(-1) + sum_exp_global = sum_exp_lo + sum_exp_hi # simulated SUM all-reduce + + logits_norm_lo = logits_lo.float() - logits_max.unsqueeze(-1) + logits_norm_hi = logits_hi.float() - logits_max.unsqueeze(-1) + + # entropy = log(sum_exp_global) - (exp · logits_norm).sum(-1) / sum_exp_global + dot_lo = (exp_lo * logits_norm_lo).sum(-1) + dot_hi = (exp_hi * logits_norm_hi).sum(-1) + dot_global = dot_lo + dot_hi # simulated SUM all-reduce + + entropy_per_tok = sum_exp_global.log() - dot_global / sum_exp_global + denom = label_counts.float().clamp(min=1) + manual_vp_entropy = (entropy_per_tok * mask.float() / denom).sum() + + _assert_close(ref_entropy, manual_vp_entropy, msg="vocab_parallel_entropy") + + +# --------------------------------------------------------------------------- +# 8. Consistency with new_logprobs_mean normalization +# --------------------------------------------------------------------------- + + +def test_old_logprobs_normalization_matches_new_logprobs_pattern(): + """ + old_logprobs metric uses the same normalization as new_logprobs_mean: + sum(value * mask / label_counts.clamp(1)) + Verify that when old == new (zero perturbation), old_logprobs == new_logprobs_mean. + """ + torch.manual_seed(6) + seq_len, vocab = 14, 20 + logits = torch.randn(seq_len, vocab, device=device) + target = torch.randint(0, vocab, (seq_len,), device=device) + label_counts = torch.full((seq_len,), seq_len, device=device) + + # old_log_probs = actual new_log_probs (no perturbation) + with torch.no_grad(): + new_lp = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1) + + old_log_probs = new_lp.detach() + advantages = torch.randn(seq_len, device=device) + + got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) + + # new_logprobs_mean pattern (from grpo.py fused function) + mask = (target >= 0).float() + denom = label_counts.float().clamp(min=1) + expected_new_lp_mean = (new_lp * mask / denom).sum() + + _assert_close(got.old_logprobs, expected_new_lp_mean, msg="old_logprobs_vs_new_logprobs_mean") + + # ratio should be ~1 everywhere, kl should be ~0 + _assert_close(got.ratio, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_at_1", atol=1e-4) + _assert_close(got.kl_new_old, torch.zeros(()), msg="kl_at_zero", atol=1e-4) From b856e3971404875fcf7bc3f5bba26bba1e331624 Mon Sep 17 00:00:00 2001 From: bigximik Date: Mon, 27 Apr 2026 16:36:37 +0000 Subject: [PATCH 2/7] grpo: align metric names with DeepSpeed path MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rename four metrics to match DeepSpeed's naming exactly so runs on both backends produce comparable WandB keys: ratio → ratio_new_old ratio_sum → ratio_new_old_sum ratio_sq_sum → ratio_new_old_squared_sum clamp_frac → clamp_log_ratio_new_old_indicator --- fast_llm/layers/language_model/loss/grpo.py | 32 ++++++------- .../layers/language_model/loss/pg_metrics.py | 47 ++++++++++++------- tests/layers/test_grpo_metrics.py | 25 ++++++---- 3 files changed, 60 insertions(+), 44 deletions(-) diff --git 
a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 4cb66522c..ab75d2f01 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -84,22 +84,22 @@ def _register_pg_metrics( name = self._name # Per-token mean metrics: divide by num_docs to match new_logprobs_mean normalization. - for attr, suffix in ( - ("old_logprobs", "old_logprobs"), - ("ratio", "ratio"), - ("kl_new_old", "kl_new_old"), - ("clamp_frac", "clamp_frac"), - ("advantage", "advantage"), + for attr in ( + "old_logprobs", + "ratio_new_old", + "kl_new_old", + "clamp_log_ratio_new_old_indicator", + "advantage", ): - self._register_loss(f"{name}_{suffix}", getattr(metrics, attr) / num_docs, losses) + self._register_loss(f"{name}_{attr}", getattr(metrics, attr) / num_docs, losses) # Raw sum metrics (no per-doc normalization). - for attr, suffix in ( - ("ratio_sum", "ratio_sum"), - ("ratio_sq_sum", "ratio_sq_sum"), - ("num_tokens", "num_tokens"), + for attr in ( + "ratio_new_old_sum", + "ratio_new_old_squared_sum", + "num_tokens", ): - self._register_loss(f"{name}_{suffix}", getattr(metrics, attr), losses) + self._register_loss(f"{name}_{attr}", getattr(metrics, attr), losses) # MAX/MIN metrics: pass correct reduce_op for sequence-parallel mode. self._register_loss( @@ -124,11 +124,11 @@ def get_loss_definitions(self) -> list[LossDef]: name = self._name defs += [ LossDef(f"{name}_old_logprobs"), - LossDef(f"{name}_ratio"), - LossDef(f"{name}_ratio_sum"), - LossDef(f"{name}_ratio_sq_sum"), + LossDef(f"{name}_ratio_new_old"), + LossDef(f"{name}_ratio_new_old_sum"), + LossDef(f"{name}_ratio_new_old_squared_sum"), LossDef(f"{name}_kl_new_old"), - LossDef(f"{name}_clamp_frac"), + LossDef(f"{name}_clamp_log_ratio_new_old_indicator"), LossDef(f"{name}_advantage"), LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), diff --git a/fast_llm/layers/language_model/loss/pg_metrics.py b/fast_llm/layers/language_model/loss/pg_metrics.py index 72c8c811a..1dec3b3ea 100644 --- a/fast_llm/layers/language_model/loss/pg_metrics.py +++ b/fast_llm/layers/language_model/loss/pg_metrics.py @@ -15,18 +15,18 @@ class PolicyGradientMetrics: sum(value * mask / label_counts.clamp(1)) The caller must then divide by num_documents_in_batch for the final logged value. - ratio_sum / ratio_sq_sum are raw masked sums (no label_counts division) for ESS. + ratio_new_old_sum / ratio_new_old_squared_sum are raw masked sums (no label_counts division) for ESS. max_advantage / min_advantage are raw per-local-batch extrema; the caller must all_reduce them with ReduceOp.MAX / ReduceOp.MIN across SDP ranks. 
""" old_logprobs: torch.Tensor # per-token mean (label_counts normalised) - ratio: torch.Tensor # per-token mean IS ratio - ratio_sum: torch.Tensor # raw masked sum (ESS numerator) - ratio_sq_sum: torch.Tensor # raw masked sum (ESS denominator) + ratio_new_old: torch.Tensor # per-token mean IS ratio + ratio_new_old_sum: torch.Tensor # raw masked sum (ESS numerator) + ratio_new_old_squared_sum: torch.Tensor # raw masked sum (ESS denominator) kl_new_old: torch.Tensor # per-token mean Schulman KL approx - clamp_frac: torch.Tensor # per-token mean clipping indicator + clamp_log_ratio_new_old_indicator: torch.Tensor # per-token mean clipping indicator advantage: torch.Tensor # per-token mean max_advantage: torch.Tensor # max over masked tokens (caller does MAX all-reduce) min_advantage: torch.Tensor # min over masked tokens (caller does MIN all-reduce) @@ -74,11 +74,11 @@ def _compute_pg_base_metrics( kl = ratio - log_ratio - 1.0 old_lp = (old_log_probabilities * mask / denom).sum() - ratio_mean = (ratio * mask / denom).sum() - ratio_sum = (ratio * mask).sum() - ratio_sq_sum = (ratio * ratio * mask).sum() + ratio_new_old_mean = (ratio * mask / denom).sum() + ratio_new_old_sum = (ratio * mask).sum() + ratio_new_old_squared_sum = (ratio * ratio * mask).sum() kl_mean = (kl * mask / denom).sum() - clamp_mean = (clipped.float() * mask / denom).sum() + clamp_indicator_mean = (clipped.float() * mask / denom).sum() adv_mean = (advantages * mask / denom).sum() num_tokens = mask.sum() @@ -88,7 +88,18 @@ def _compute_pg_base_metrics( max_adv = torch.where(loss_mask, advantages, neg_inf).max() min_adv = torch.where(loss_mask, advantages, pos_inf).min() - return old_lp, ratio_mean, ratio_sum, ratio_sq_sum, kl_mean, clamp_mean, adv_mean, max_adv, min_adv, num_tokens + return ( + old_lp, + ratio_new_old_mean, + ratio_new_old_sum, + ratio_new_old_squared_sum, + kl_mean, + clamp_indicator_mean, + adv_mean, + max_adv, + min_adv, + num_tokens, + ) def compute_chunked_entropy( @@ -163,11 +174,11 @@ def compute_policy_gradient_metrics( ) -> PolicyGradientMetrics: ( old_lp, - ratio_mean, - ratio_sum, - ratio_sq_sum, + ratio_new_old_mean, + ratio_new_old_sum, + ratio_new_old_squared_sum, kl_mean, - clamp_mean, + clamp_indicator_mean, adv_mean, max_adv, min_adv, @@ -197,11 +208,11 @@ def compute_policy_gradient_metrics( return PolicyGradientMetrics( old_logprobs=old_lp, - ratio=ratio_mean, - ratio_sum=ratio_sum, - ratio_sq_sum=ratio_sq_sum, + ratio_new_old=ratio_new_old_mean, + ratio_new_old_sum=ratio_new_old_sum, + ratio_new_old_squared_sum=ratio_new_old_squared_sum, kl_new_old=kl_mean, - clamp_frac=clamp_mean, + clamp_log_ratio_new_old_indicator=clamp_indicator_mean, advantage=adv_mean, max_advantage=max_adv, min_advantage=min_adv, diff --git a/tests/layers/test_grpo_metrics.py b/tests/layers/test_grpo_metrics.py index f3ac4c5fe..1406fa514 100644 --- a/tests/layers/test_grpo_metrics.py +++ b/tests/layers/test_grpo_metrics.py @@ -53,11 +53,11 @@ def _manual_metrics(logits, target, old_log_probs, advantages, label_counts, eps return dict( old_logprobs=old_lp, - ratio=ratio_mean, - ratio_sum=ratio_sum, - ratio_sq_sum=ratio_sq_sum, + ratio_new_old=ratio_mean, + ratio_new_old_sum=ratio_sum, + ratio_new_old_squared_sum=ratio_sq_sum, kl_new_old=kl_mean, - clamp_frac=clamp_mean, + clamp_log_ratio_new_old_indicator=clamp_mean, advantage=adv_mean, max_advantage=max_adv, min_advantage=min_adv, @@ -206,7 +206,12 @@ def test_clamp_fraction_known(): got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, 
eps_lo=0.1, eps_hi=0.1) expected_clamp_frac = 3.0 / seq_len # 3 out of 5 tokens clipped - _assert_close(got.clamp_frac, torch.tensor(expected_clamp_frac), msg="clamp_frac", atol=1e-5) + _assert_close( + got.clamp_log_ratio_new_old_indicator, + torch.tensor(expected_clamp_frac), + msg="clamp_log_ratio_new_old_indicator", + atol=1e-5, + ) # --------------------------------------------------------------------------- @@ -261,11 +266,11 @@ def test_mock_sdp_split(): # SUM metrics accumulate across both halves for attr in ( "old_logprobs", - "ratio", - "ratio_sum", - "ratio_sq_sum", + "ratio_new_old", + "ratio_new_old_sum", + "ratio_new_old_squared_sum", "kl_new_old", - "clamp_frac", + "clamp_log_ratio_new_old_indicator", "advantage", "num_tokens", ): @@ -373,5 +378,5 @@ def test_old_logprobs_normalization_matches_new_logprobs_pattern(): _assert_close(got.old_logprobs, expected_new_lp_mean, msg="old_logprobs_vs_new_logprobs_mean") # ratio should be ~1 everywhere, kl should be ~0 - _assert_close(got.ratio, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_at_1", atol=1e-4) + _assert_close(got.ratio_new_old, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_new_old_at_1", atol=1e-4) _assert_close(got.kl_new_old, torch.zeros(()), msg="kl_at_zero", atol=1e-4) From d360a46bcc1b9629d83e667ca5babda3c1c53222 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 12:34:44 -0400 Subject: [PATCH 3/7] grpo: address review feedback on metrics - Inline pg_metrics.py into grpo.py; rename to GRPOMetrics - Drop entropy_chunk_size; reuse fused_softmax_base outputs for entropy - Replace two bool flags with a single metrics: GRPOMetricsLevel enum - Rename clamp_log_ratio_new_old_indicator -> clipped_ratio_fraction - Raise on metrics enabled with pipeline_parallel > 1 (MAX/MIN reduce would be corrupted by the zero placeholder on empty pipeline ranks) - Migrate tests into tests/layers/test_lm_losses.py, reusing the existing helpers and parametrization (single + distributed runner) Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 28 +- fast_llm/layers/language_model/loss/grpo.py | 116 +++++- .../layers/language_model/loss/pg_metrics.py | 221 ---------- tests/layers/test_grpo_metrics.py | 382 ------------------ tests/layers/test_lm_losses.py | 124 +++++- 5 files changed, 237 insertions(+), 634 deletions(-) delete mode 100644 fast_llm/layers/language_model/loss/pg_metrics.py delete mode 100644 tests/layers/test_grpo_metrics.py diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 4f91724a2..2c27d2e65 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -1,3 +1,4 @@ +import enum import typing import warnings @@ -193,6 +194,12 @@ def loss_class(self) -> "type[LanguageModelZLoss]": return LanguageModelZLoss +class GRPOMetricsLevel(enum.StrEnum): + none = "none" + basic = "basic" + with_entropy = "with_entropy" + + @config_class(dynamic_type={LanguageModelLossConfig: "grpo"}) class LanguageModelGRPOLossConfig(LanguageModelLossConfig): @@ -205,21 +212,16 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc="Enable triton implementation. 
Default: use if available.",
         hint=FieldHint.expert,
     )
-    compute_extra_metrics: bool = Field(
-        default=False,
-        desc="Log additional GRPO metrics: old_logprobs, ratio, KL(new||old), advantage stats, clamp fraction, token count.",
-        hint=FieldHint.feature,
-    )
-    compute_entropy_metric: bool = Field(
-        default=False,
-        desc="Also log per-token entropy (-Σ p log p). Requires a second pass over logits (~10-20%% overhead). Implies compute_extra_metrics.",
+    metrics: GRPOMetricsLevel = Field(
+        default=GRPOMetricsLevel.none,
+        desc=(
+            "Additional GRPO metrics to log. "
+            "`basic`: old_logprobs, ratio, KL(new||old), advantage stats, clipped fraction, token count. "
+            "`with_entropy`: also log per-token entropy (-Σ p log p), computed from the fused softmax intermediates. "
+            "Not supported with pipeline_parallel > 1."
+        ),
         hint=FieldHint.feature,
     )
-    entropy_chunk_size: int = Field(
-        default=4096,
-        desc="Batch chunk size for chunked entropy computation. Memory per chunk ∝ chunk_size × vocab_local.",
-        hint=FieldHint.expert,
-    )
 
     @property
     def loss_class(self) -> "type[LanguageModelGRPOLoss]":
diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index ab75d2f01..745f7abb6 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -1,18 +1,55 @@
+import dataclasses
 import functools
 import typing
 
 import torch
 
 from fast_llm.engine.base_model.config import LossDef, ReductionType
+from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
 from fast_llm.functional.utils import reduce_losses
 from fast_llm.layers.language_model.config import LanguageModelKwargs
-from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
+from fast_llm.layers.language_model.loss.config import (
+    GRPOMetricsLevel,
+    LanguageModelGRPOLossConfig,
+    LanguageModelLossKwargs,
+)
 from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
 
 
+@dataclasses.dataclass
+class GRPOMetrics:
+    old_logprobs: torch.Tensor
+    ratio_new_old: torch.Tensor
+    ratio_new_old_sum: torch.Tensor
+    ratio_new_old_squared_sum: torch.Tensor
+    kl_new_old: torch.Tensor
+    clipped_ratio_fraction: torch.Tensor
+    advantage: torch.Tensor
+    max_advantage: torch.Tensor
+    min_advantage: torch.Tensor
+    num_tokens: torch.Tensor
+    entropy: torch.Tensor | None
+
+
 class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
+    def __init__(
+        self,
+        config: ConfigType,
+        distributed_config: DistributedConfig,
+        **kwargs,
+    ):
+        super().__init__(config, distributed_config, **kwargs)
+        # MAX/MIN reductions are unsafe under pipeline parallelism: ranks without this loss layer
+        # contribute a torch.zeros([1]) placeholder in LossDef.reduce, which corrupts the extremum
+        # whenever the real value has the opposite sign.
+        if config.metrics != GRPOMetricsLevel.none and distributed_config.pipeline_parallel > 1:
+            raise NotImplementedError(
+                "GRPO extra metrics are not supported with pipeline_parallel > 1 "
+                "(MAX/MIN advantage reductions would be corrupted by the zero placeholder on empty pipeline ranks)."
+ ) + def _forward_backward( self, logits: "torch.Tensor", @@ -52,21 +89,19 @@ def _forward_backward( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) - if losses is not None and (self._config.compute_extra_metrics or self._config.compute_entropy_metric): - self._register_pg_metrics(logits, kwargs, losses, split_index) + if losses is not None and self._config.metrics != GRPOMetricsLevel.none: + self._register_extra_metrics(logits, kwargs, losses, split_index) return loss, grad - def _register_pg_metrics( + def _register_extra_metrics( self, logits: torch.Tensor, kwargs: dict[str, typing.Any], losses: dict, split_index: int, ) -> None: - from fast_llm.layers.language_model.loss.pg_metrics import compute_policy_gradient_metrics - - metrics = compute_policy_gradient_metrics( + metrics = compute_grpo_metrics( logits, self._get_labels(kwargs, split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), @@ -75,25 +110,22 @@ def _register_pg_metrics( self._config.epsilon_low, self._config.epsilon_high, self._logits_scale_factor, - vocab_parallel_group=self._parallel_dim.group if self._vocab_parallel else None, - compute_entropy=self._config.compute_entropy_metric, - entropy_chunk_size=self._config.entropy_chunk_size, + group=self._parallel_dim.group if self._vocab_parallel else None, + compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy, ) num_docs = kwargs[LanguageModelKwargs.num_documents_in_batch] name = self._name - # Per-token mean metrics: divide by num_docs to match new_logprobs_mean normalization. for attr in ( "old_logprobs", "ratio_new_old", "kl_new_old", - "clamp_log_ratio_new_old_indicator", + "clipped_ratio_fraction", "advantage", ): self._register_loss(f"{name}_{attr}", getattr(metrics, attr) / num_docs, losses) - # Raw sum metrics (no per-doc normalization). for attr in ( "ratio_new_old_sum", "ratio_new_old_squared_sum", @@ -101,7 +133,6 @@ def _register_pg_metrics( ): self._register_loss(f"{name}_{attr}", getattr(metrics, attr), losses) - # MAX/MIN metrics: pass correct reduce_op for sequence-parallel mode. 
self._register_loss(
             f"{name}_max_advantage",
             metrics.max_advantage,
@@ -120,7 +151,7 @@ def _register_pg_metrics(
 
     def get_loss_definitions(self) -> list[LossDef]:
         defs = super().get_loss_definitions() + [LossDef(self._logprob_metric_name)]
-        if self._config.compute_extra_metrics or self._config.compute_entropy_metric:
+        if self._config.metrics != GRPOMetricsLevel.none:
             name = self._name
             defs += [
                 LossDef(f"{name}_old_logprobs"),
@@ -128,13 +159,13 @@ def get_loss_definitions(self) -> list[LossDef]:
                 LossDef(f"{name}_ratio_new_old_sum"),
                 LossDef(f"{name}_ratio_new_old_squared_sum"),
                 LossDef(f"{name}_kl_new_old"),
-                LossDef(f"{name}_clamp_log_ratio_new_old_indicator"),
+                LossDef(f"{name}_clipped_ratio_fraction"),
                 LossDef(f"{name}_advantage"),
                 LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum),
                 LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum),
                 LossDef(f"{name}_num_tokens"),
             ]
-            if self._config.compute_entropy_metric:
+            if self._config.metrics == GRPOMetricsLevel.with_entropy:
                 defs.append(LossDef(f"{name}_entropy"))
         return defs
 
@@ -148,6 +179,57 @@ def _logprob_metric_name(self) -> str:
         return f"{self._name}_new_logprobs"
 
 
+@torch.compile
+def compute_grpo_metrics(
+    logits: torch.Tensor,  # (*batch, vocab_local)
+    target: torch.Tensor,  # (*batch,)
+    old_log_probabilities: torch.Tensor,  # (*batch,)
+    advantages: torch.Tensor,  # (*batch,)
+    label_counts: torch.Tensor,  # (*batch,) — global per-sequence count broadcast per token
+    epsilon_low: float,
+    epsilon_high: float,
+    logits_scale_factor: float,
+    group: torch.distributed.ProcessGroup | None,
+    compute_entropy: bool,
+) -> GRPOMetrics:
+    loss_mask = target >= 0
+    mask = loss_mask.float()
+    denom = label_counts.float().clamp(min=1)
+    masked = mask / denom
+
+    logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group)
+    predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group)
+    new_log_probs = predicted_logits - sum_exp_logits.log()
+
+    log_ratio = new_log_probs - old_log_probabilities
+    ratio = log_ratio.exp()
+    clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high)
+    # Schulman k3 KL approximation: exp(r) - r - 1
+    kl = ratio - log_ratio - 1.0
+
+    neg_inf = advantages.new_full((), float("-inf"))
+    pos_inf = advantages.new_full((), float("inf"))
+
+    entropy = None
+    if compute_entropy:
+        # (exp_logits · logits_norm) covers only the local vocab shard, so in
+        # vocab-parallel mode it needs the same SUM all-reduce as sum_exp_logits.
+        entropy_dot = (exp_logits * logits_norm).sum(-1)
+        if group is not None:
+            torch.distributed.all_reduce(entropy_dot, group=group)
+        entropy_per_token = sum_exp_logits.log() - entropy_dot / sum_exp_logits
+        entropy = (entropy_per_token * masked).sum()
+
+    return GRPOMetrics(
+        old_logprobs=(old_log_probabilities * masked).sum(),
+        ratio_new_old=(ratio * masked).sum(),
+        ratio_new_old_sum=(ratio * mask).sum(),
+        ratio_new_old_squared_sum=(ratio * ratio * mask).sum(),
+        kl_new_old=(kl * masked).sum(),
+        clipped_ratio_fraction=(clipped.float() * masked).sum(),
+        advantage=(advantages * masked).sum(),
+        max_advantage=torch.where(loss_mask, advantages, neg_inf).max(),
+        min_advantage=torch.where(loss_mask, advantages, pos_inf).min(),
+        num_tokens=mask.sum(),
+        entropy=entropy,
+    )
+
+
 @torch.compile
 def fused_grpo_loss_forward_backward(
     logits: torch.Tensor,  # (*batch, vocab)
diff --git a/fast_llm/layers/language_model/loss/pg_metrics.py b/fast_llm/layers/language_model/loss/pg_metrics.py
deleted file mode 100644
index 1dec3b3ea..000000000
--- a/fast_llm/layers/language_model/loss/pg_metrics.py
+++ /dev/null
@@ -1,221 +0,0 @@
-import dataclasses
-
-import torch
-import torch.distributed
-
-from fast_llm.functional.entropy_loss import
fused_predicted_logits_from_labels, fused_softmax_base - - -@dataclasses.dataclass -class PolicyGradientMetrics: - """ - Scalar metrics for policy-gradient losses (GRPO, PPO, …). - - All per-token-mean fields use the same normalization as new_logprobs_mean: - sum(value * mask / label_counts.clamp(1)) - The caller must then divide by num_documents_in_batch for the final logged value. - - ratio_new_old_sum / ratio_new_old_squared_sum are raw masked sums (no label_counts division) for ESS. - - max_advantage / min_advantage are raw per-local-batch extrema; the caller must - all_reduce them with ReduceOp.MAX / ReduceOp.MIN across SDP ranks. - """ - - old_logprobs: torch.Tensor # per-token mean (label_counts normalised) - ratio_new_old: torch.Tensor # per-token mean IS ratio - ratio_new_old_sum: torch.Tensor # raw masked sum (ESS numerator) - ratio_new_old_squared_sum: torch.Tensor # raw masked sum (ESS denominator) - kl_new_old: torch.Tensor # per-token mean Schulman KL approx - clamp_log_ratio_new_old_indicator: torch.Tensor # per-token mean clipping indicator - advantage: torch.Tensor # per-token mean - max_advantage: torch.Tensor # max over masked tokens (caller does MAX all-reduce) - min_advantage: torch.Tensor # min over masked tokens (caller does MIN all-reduce) - num_tokens: torch.Tensor # raw masked sum - entropy: torch.Tensor | None # per-token mean entropy; None when not requested - - -@torch.compile -def _compute_pg_base_metrics( - logits: torch.Tensor, # (*batch, vocab_local) - target: torch.Tensor, # (*batch,) - old_log_probabilities: torch.Tensor, # (*batch,) - advantages: torch.Tensor, # (*batch,) - label_counts: torch.Tensor, # (*batch,) global per-seq count, broadcast per token - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float, - group: torch.distributed.ProcessGroup | None, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, -]: - """Compute all non-entropy policy-gradient metrics in a single fused pass.""" - loss_mask = target >= 0 - mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - - logits_norm, _, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) - predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) - new_log_probs = predicted_logits - sum_exp_logits.log() - - log_ratio = new_log_probs - old_log_probabilities - ratio = log_ratio.exp() - clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) - - # Schulman KL approximation: exp(r) - r - 1 - kl = ratio - log_ratio - 1.0 - - old_lp = (old_log_probabilities * mask / denom).sum() - ratio_new_old_mean = (ratio * mask / denom).sum() - ratio_new_old_sum = (ratio * mask).sum() - ratio_new_old_squared_sum = (ratio * ratio * mask).sum() - kl_mean = (kl * mask / denom).sum() - clamp_indicator_mean = (clipped.float() * mask / denom).sum() - adv_mean = (advantages * mask / denom).sum() - num_tokens = mask.sum() - - # max/min over masked positions; fill non-masked with sentinel values - neg_inf = advantages.new_full((), float("-inf")) - pos_inf = advantages.new_full((), float("inf")) - max_adv = torch.where(loss_mask, advantages, neg_inf).max() - min_adv = torch.where(loss_mask, advantages, pos_inf).min() - - return ( - old_lp, - ratio_new_old_mean, - ratio_new_old_sum, - ratio_new_old_squared_sum, - kl_mean, - clamp_indicator_mean, - adv_mean, - max_adv, - min_adv, - num_tokens, - ) - - 
-def compute_chunked_entropy( - logits: torch.Tensor, # (*batch, vocab_local) - target: torch.Tensor, # (*batch,) — used only for loss_mask - label_counts: torch.Tensor, # (*batch,) - logits_scale_factor: float, - group: torch.distributed.ProcessGroup | None, - chunk_size: int = 4096, -) -> torch.Tensor: - """ - Compute per-token entropy -Σ p log p, chunked over the batch dimension to - limit peak memory. Supports vocab-parallel via all-reduce per chunk. - - Returns a scalar using the same label_counts normalisation as other mean metrics - (sum of per-sequence mean entropies). Caller must divide by num_documents_in_batch. - - Memory per chunk: chunk_size × vocab_local × 4 bytes. - At chunk_size=4096, vocab_local=19K (8-way TP): ~300 MB. - - Entropy formula (numerically stable): - entropy_i = log(Σ exp(x_j - x_max)) - Σ(exp(x_j - x_max) * (x_j - x_max)) / Σ exp(x_j - x_max) - = log(sum_exp) - (exp_logits · logits_norm).sum() / sum_exp - """ - loss_mask = target >= 0 - mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - - batch_size = logits.shape[0] - total = logits.new_zeros(()) - - for start in range(0, batch_size, chunk_size): - sl = slice(start, start + chunk_size) - logits_chunk = logits[sl] - - # Recompute softmax base for this chunk only. - # Scale here since fused_softmax_base expects the full tensor for max/all-reduce; - # we handle it manually to avoid a full-tensor pass. - if logits_scale_factor != 1.0: - logits_chunk = logits_chunk * logits_scale_factor - - logits_max = logits_chunk.float().max(dim=-1).values - if group is not None: - torch.distributed.all_reduce(logits_max, op=torch.distributed.ReduceOp.MAX, group=group) - - logits_norm_chunk = logits_chunk.float() - logits_max.unsqueeze(-1) - exp_chunk = logits_norm_chunk.exp() - sum_exp_chunk = exp_chunk.sum(dim=-1) - if group is not None: - torch.distributed.all_reduce(sum_exp_chunk, op=torch.distributed.ReduceOp.SUM, group=group) - - # entropy_i = log(sum_exp) - (exp · logits_norm).sum(-1) / sum_exp - entropy_chunk = sum_exp_chunk.log() - (exp_chunk * logits_norm_chunk).sum(-1) / sum_exp_chunk - - total = total + (entropy_chunk * mask[sl] / denom[sl]).sum() - - return total - - -def compute_policy_gradient_metrics( - logits: torch.Tensor, - target: torch.Tensor, - old_log_probabilities: torch.Tensor, - advantages: torch.Tensor, - label_counts: torch.Tensor, - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float, - vocab_parallel_group: torch.distributed.ProcessGroup | None, - compute_entropy: bool = False, - entropy_chunk_size: int = 4096, -) -> PolicyGradientMetrics: - ( - old_lp, - ratio_new_old_mean, - ratio_new_old_sum, - ratio_new_old_squared_sum, - kl_mean, - clamp_indicator_mean, - adv_mean, - max_adv, - min_adv, - num_tokens, - ) = _compute_pg_base_metrics( - logits, - target, - old_log_probabilities, - advantages, - label_counts, - epsilon_low, - epsilon_high, - logits_scale_factor, - vocab_parallel_group, - ) - - entropy = None - if compute_entropy: - entropy = compute_chunked_entropy( - logits, - target, - label_counts, - logits_scale_factor, - vocab_parallel_group, - entropy_chunk_size, - ) - - return PolicyGradientMetrics( - old_logprobs=old_lp, - ratio_new_old=ratio_new_old_mean, - ratio_new_old_sum=ratio_new_old_sum, - ratio_new_old_squared_sum=ratio_new_old_squared_sum, - kl_new_old=kl_mean, - clamp_log_ratio_new_old_indicator=clamp_indicator_mean, - advantage=adv_mean, - max_advantage=max_adv, - min_advantage=min_adv, - num_tokens=num_tokens, - entropy=entropy, - ) diff 
--git a/tests/layers/test_grpo_metrics.py b/tests/layers/test_grpo_metrics.py deleted file mode 100644 index 1406fa514..000000000 --- a/tests/layers/test_grpo_metrics.py +++ /dev/null @@ -1,382 +0,0 @@ -""" -Unit tests for pg_metrics.py — PolicyGradientMetrics computation. - -All tests run on CPU (or GPU if available) without distributed communication -(vocab_parallel_group=None). Distributed reduction is exercised conceptually -via the mock-SDP and mock-vocab-parallel sections. -""" - -import math - -import torch - -from fast_llm.layers.language_model.loss.pg_metrics import ( - compute_chunked_entropy, - compute_policy_gradient_metrics, -) - -# --------------------------------------------------------------------------- -# helpers -# --------------------------------------------------------------------------- - -device = "cuda" if torch.cuda.is_available() else "cpu" - - -def _manual_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo, eps_hi): - """Reference implementation (pure PyTorch, no compilation).""" - loss_mask = target >= 0 - mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - - log_softmax = torch.log_softmax(logits.float(), dim=-1) - new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) - - log_ratio = new_log_probs - old_log_probs.float() - ratio = log_ratio.exp() - clipped = (ratio < 1.0 - eps_lo) | (ratio > 1.0 + eps_hi) - kl = ratio - log_ratio - 1.0 - - old_lp = (old_log_probs.float() * mask / denom).sum() - ratio_mean = (ratio * mask / denom).sum() - ratio_sum = (ratio * mask).sum() - ratio_sq_sum = (ratio * ratio * mask).sum() - kl_mean = (kl * mask / denom).sum() - clamp_mean = (clipped.float() * mask / denom).sum() - adv_mean = (advantages.float() * mask / denom).sum() - max_adv = advantages.float()[loss_mask].max() - min_adv = advantages.float()[loss_mask].min() - num_tokens = mask.sum() - - probs = log_softmax.exp() - entropy_per_token = -(probs * log_softmax).sum(-1) - entropy_mean = (entropy_per_token * mask / denom).sum() - - return dict( - old_logprobs=old_lp, - ratio_new_old=ratio_mean, - ratio_new_old_sum=ratio_sum, - ratio_new_old_squared_sum=ratio_sq_sum, - kl_new_old=kl_mean, - clamp_log_ratio_new_old_indicator=clamp_mean, - advantage=adv_mean, - max_advantage=max_adv, - min_advantage=min_adv, - num_tokens=num_tokens, - entropy=entropy_mean, - ) - - -def _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.2, eps_hi=0.2, chunk_size=4096): - return compute_policy_gradient_metrics( - logits, - target, - old_log_probs, - advantages, - label_counts, - eps_lo, - eps_hi, - logits_scale_factor=1.0, - vocab_parallel_group=None, - compute_entropy=True, - entropy_chunk_size=chunk_size, - ) - - -def _assert_close(a, b, msg="", atol=1e-5): - assert abs(a.item() - b.item()) < atol, f"{msg}: got {a.item():.8f}, expected {b.item():.8f}" - - -# --------------------------------------------------------------------------- -# 1. 
Single sequence — all metrics match manual computation -# --------------------------------------------------------------------------- - - -def test_single_sequence_all_metrics(): - torch.manual_seed(0) - seq_len, vocab = 12, 8 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - old_log_probs = torch.randn(seq_len, device=device) - 3.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) # all tokens in one seq - - ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - - for key in ref: - _assert_close(getattr(got, key), ref[key], msg=key) - - -# --------------------------------------------------------------------------- -# 2. Packed multi-sequence — per-sequence normalization -# --------------------------------------------------------------------------- - - -def test_packed_multi_sequence(): - """ - Three sequences of lengths [4, 6, 5] packed into one flat batch (15 tokens). - label_counts broadcasts the global per-sequence count. - """ - torch.manual_seed(1) - lengths = [4, 6, 5] - total = sum(lengths) - vocab = 10 - - logits = torch.randn(total, vocab, device=device) - target = torch.randint(0, vocab, (total,), device=device) - old_log_probs = torch.randn(total, device=device) - 2.0 - advantages = torch.randn(total, device=device) - label_counts = torch.tensor([l for l in lengths for _ in range(l)], dtype=torch.long, device=device) - - ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2) - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - - for key in ref: - _assert_close(getattr(got, key), ref[key], msg=key) - - -# --------------------------------------------------------------------------- -# 3. Masked tokens — masked-out tokens must not contribute -# --------------------------------------------------------------------------- - - -def test_masked_tokens_do_not_contribute(): - """ - A batch where half the tokens are masked (target=-100). - Metrics computed on full batch should equal metrics on unmasked subset only. - """ - torch.manual_seed(2) - seq_len, vocab = 20, 16 - logits = torch.randn(seq_len, vocab, device=device) - target_full = torch.randint(0, vocab, (seq_len,), device=device) - - # mask the first half - mask_bool = torch.ones(seq_len, dtype=torch.bool, device=device) - mask_bool[: seq_len // 2] = False - target_masked = torch.where(mask_bool, target_full, torch.full_like(target_full, -100)) - - old_log_probs = torch.randn(seq_len, device=device) - 2.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), mask_bool.sum().item(), device=device) - - # reference: only the unmasked slice - half = seq_len // 2 - ref = _manual_metrics( - logits[half:], - target_full[half:], - old_log_probs[half:], - advantages[half:], - label_counts[half:], - 0.2, - 0.2, - ) - got = _run_metrics(logits, target_masked, old_log_probs, advantages, label_counts) - - for key in ref: - _assert_close(getattr(got, key), ref[key], msg=f"masked_{key}") - - -# --------------------------------------------------------------------------- -# 4. Clamp fraction — known ratios → known clamp_frac -# --------------------------------------------------------------------------- - - -def test_clamp_fraction_known(): - """ - Construct logits so that probability_ratio is exactly known. 
- With eps_lo=0.1, eps_hi=0.1 and 5 tokens:
- 3 tokens outside the clip range, 2 inside → clamp_frac = 3/5.
- """
- seq_len, vocab = 5, 4
- # uniform logits → probabilities = 1/vocab for any label
- logits = torch.zeros(seq_len, vocab, device=device)
- target = torch.zeros(seq_len, dtype=torch.long, device=device) # all label=0
- # p_new = 1/4, so new_log_prob = log(0.25)
- new_lp = math.log(1.0 / vocab)
-
- # Set old_log_probs so ratio = exp(new - old) is known per token
- # ratios: [0.85, 1.0, 1.05, 1.2, 0.75] (eps=0.1 → clip outside (0.9, 1.1))
- # clipped: True, False, False, True, True → 3 clipped
- ratios = torch.tensor([0.85, 1.0, 1.05, 1.2, 0.75], device=device)
- old_log_probs = torch.full((seq_len,), new_lp, device=device) - ratios.log()
-
- advantages = torch.ones(seq_len, device=device)
- label_counts = torch.full((seq_len,), seq_len, device=device)
-
- got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, eps_lo=0.1, eps_hi=0.1)
-
- expected_clamp_frac = 3.0 / seq_len # 3 out of 5 tokens clipped
- _assert_close(
- got.clamp_log_ratio_new_old_indicator,
- torch.tensor(expected_clamp_frac),
- msg="clamp_log_ratio_new_old_indicator",
- atol=1e-5,
- )
-
-
-# ---------------------------------------------------------------------------
-# 5. Entropy correctness — small vocab, verify chunked vs reference
-# ---------------------------------------------------------------------------
-
-
-def test_entropy_matches_manual():
- """Small vocab so we can compute entropy exactly by hand."""
- torch.manual_seed(3)
- seq_len, vocab = 8, 6
- logits = torch.randn(seq_len, vocab, device=device)
- target = torch.randint(0, vocab, (seq_len,), device=device)
- old_log_probs = torch.randn(seq_len, device=device) - 2.0
- advantages = torch.randn(seq_len, device=device)
- label_counts = torch.full((seq_len,), seq_len, device=device)
-
- # Reference entropy
- ref = _manual_metrics(logits, target, old_log_probs, advantages, label_counts, 0.2, 0.2)
-
- # Test with different chunk sizes (including chunk_size=1 and chunk_size>seq_len)
- for chunk_size in (1, 3, seq_len, seq_len + 10):
- got = _run_metrics(logits, target, old_log_probs, advantages, label_counts, chunk_size=chunk_size)
- _assert_close(got.entropy, ref["entropy"], msg=f"entropy chunk_size={chunk_size}")
-
-
-# ---------------------------------------------------------------------------
-# 6. Mock SDP — split batch in half, verify sum/max/min consistency
-# ---------------------------------------------------------------------------
-
-
-def test_mock_sdp_split():
- """
- Simulate two SDP ranks each holding half the batch.
- SUM metrics on full batch == sum of the two halves.
- MAX/MIN metrics on full batch == max/min of the two halves.
- """ - torch.manual_seed(4) - seq_len, vocab = 18, 12 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - old_log_probs = torch.randn(seq_len, device=device) - 2.0 - advantages = torch.randn(seq_len, device=device) - label_counts = torch.full((seq_len,), seq_len // 2, device=device) - - half = seq_len // 2 - - full = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - lo = _run_metrics(logits[:half], target[:half], old_log_probs[:half], advantages[:half], label_counts[:half]) - hi = _run_metrics(logits[half:], target[half:], old_log_probs[half:], advantages[half:], label_counts[half:]) - - # SUM metrics accumulate across both halves - for attr in ( - "old_logprobs", - "ratio_new_old", - "ratio_new_old_sum", - "ratio_new_old_squared_sum", - "kl_new_old", - "clamp_log_ratio_new_old_indicator", - "advantage", - "num_tokens", - ): - combined = getattr(lo, attr) + getattr(hi, attr) - _assert_close(getattr(full, attr), combined, msg=f"sdp_{attr}") - - # MAX/MIN are extrema across both halves - _assert_close(full.max_advantage, torch.max(lo.max_advantage, hi.max_advantage), msg="sdp_max_adv") - _assert_close(full.min_advantage, torch.min(lo.min_advantage, hi.min_advantage), msg="sdp_min_adv") - - # Entropy (SUM metric) - _assert_close(full.entropy, lo.entropy + hi.entropy, msg="sdp_entropy") - - -# --------------------------------------------------------------------------- -# 7. Mock vocab-parallel entropy — split logits along vocab dim -# --------------------------------------------------------------------------- - - -def test_mock_vocab_parallel_entropy(): - """ - Simulate 2-way vocab-parallel: split logits along the vocab dim. - Each "rank" computes a partial softmax; the global entropy should - match single-rank computation (all-reduce simulated manually). - """ - torch.manual_seed(5) - seq_len, vocab = 10, 16 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) - mask = torch.ones(seq_len, dtype=torch.bool, device=device) - - # Reference: single rank, full vocab - ref_entropy = compute_chunked_entropy( - logits, - target, - label_counts, - logits_scale_factor=1.0, - group=None, - chunk_size=seq_len, - ) - - # Simulate vocab-parallel: split vocab into [0:8] and [8:16] - # Both ranks see the same sequence but different vocab shards. 
- # global max is needed for numerical stability: - logits_max = logits.float().max(dim=-1).values # (seq_len,) - - half_v = vocab // 2 - logits_lo = logits[:, :half_v] - logits_hi = logits[:, half_v:] - - # Per rank: compute local sum_exp relative to global max - exp_lo = (logits_lo.float() - logits_max.unsqueeze(-1)).exp() - exp_hi = (logits_hi.float() - logits_max.unsqueeze(-1)).exp() - sum_exp_lo = exp_lo.sum(-1) - sum_exp_hi = exp_hi.sum(-1) - sum_exp_global = sum_exp_lo + sum_exp_hi # simulated SUM all-reduce - - logits_norm_lo = logits_lo.float() - logits_max.unsqueeze(-1) - logits_norm_hi = logits_hi.float() - logits_max.unsqueeze(-1) - - # entropy = log(sum_exp_global) - (exp · logits_norm).sum(-1) / sum_exp_global - dot_lo = (exp_lo * logits_norm_lo).sum(-1) - dot_hi = (exp_hi * logits_norm_hi).sum(-1) - dot_global = dot_lo + dot_hi # simulated SUM all-reduce - - entropy_per_tok = sum_exp_global.log() - dot_global / sum_exp_global - denom = label_counts.float().clamp(min=1) - manual_vp_entropy = (entropy_per_tok * mask.float() / denom).sum() - - _assert_close(ref_entropy, manual_vp_entropy, msg="vocab_parallel_entropy") - - -# --------------------------------------------------------------------------- -# 8. Consistency with new_logprobs_mean normalization -# --------------------------------------------------------------------------- - - -def test_old_logprobs_normalization_matches_new_logprobs_pattern(): - """ - old_logprobs metric uses the same normalization as new_logprobs_mean: - sum(value * mask / label_counts.clamp(1)) - Verify that when old == new (zero perturbation), old_logprobs == new_logprobs_mean. - """ - torch.manual_seed(6) - seq_len, vocab = 14, 20 - logits = torch.randn(seq_len, vocab, device=device) - target = torch.randint(0, vocab, (seq_len,), device=device) - label_counts = torch.full((seq_len,), seq_len, device=device) - - # old_log_probs = actual new_log_probs (no perturbation) - with torch.no_grad(): - new_lp = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1) - - old_log_probs = new_lp.detach() - advantages = torch.randn(seq_len, device=device) - - got = _run_metrics(logits, target, old_log_probs, advantages, label_counts) - - # new_logprobs_mean pattern (from grpo.py fused function) - mask = (target >= 0).float() - denom = label_counts.float().clamp(min=1) - expected_new_lp_mean = (new_lp * mask / denom).sum() - - _assert_close(got.old_logprobs, expected_new_lp_mean, msg="old_logprobs_vs_new_logprobs_mean") - - # ratio should be ~1 everywhere, kl should be ~0 - _assert_close(got.ratio_new_old, torch.tensor(1.0) * (mask / denom).sum(), msg="ratio_new_old_at_1", atol=1e-4) - _assert_close(got.kl_new_old, torch.zeros(()), msg="kl_at_zero", atol=1e-4) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 9b93aeb66..e24b3236e 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -16,7 +16,7 @@ from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss -from fast_llm.layers.language_model.loss.grpo import fused_grpo_loss_forward_backward +from fast_llm.layers.language_model.loss.grpo import compute_grpo_metrics, fused_grpo_loss_forward_backward from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss 
from fast_llm.utils import Assert @@ -121,6 +121,47 @@ def reference_dpo_loss( return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() +def reference_grpo_metrics( + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + label_counts: torch.Tensor, + epsilon_low: float, + epsilon_high: float, + logits_scale_factor: float, + compute_entropy: bool, +) -> dict[str, torch.Tensor]: + log_softmax = torch.nn.functional.log_softmax(logits.float() * logits_scale_factor, dim=-1) + loss_mask = target >= 0 + mask = loss_mask.float() + masked = mask / label_counts.float().clamp(min=1) + + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + log_ratio = new_log_probs - old_log_probabilities.float() + ratio = log_ratio.exp() + clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) + kl = ratio - log_ratio - 1.0 + + metrics = { + "old_logprobs": (old_log_probabilities.float() * masked).sum(), + "ratio_new_old": (ratio * masked).sum(), + "ratio_new_old_sum": (ratio * mask).sum(), + "ratio_new_old_squared_sum": (ratio * ratio * mask).sum(), + "kl_new_old": (kl * masked).sum(), + "clipped_ratio_fraction": (clipped.float() * masked).sum(), + "advantage": (advantages.float() * masked).sum(), + "max_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("-inf"))).max(), + "min_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("inf"))).min(), + "num_tokens": mask.sum(), + "entropy": None, + } + if compute_entropy: + entropy_per_token = -(log_softmax.exp() * log_softmax).sum(-1) + metrics["entropy"] = (entropy_per_token * masked).sum() + return metrics + + def reference_grpo_loss( logits: torch.Tensor, labels: torch.Tensor, @@ -304,6 +345,50 @@ def _test_grpo_loss( Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) +def _test_grpo_metrics( + batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy, group=None +): + logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( + num_columns, loss_masking, batch_shape, dtype + ) + num_labels = max(int((target >= 0).sum().item()), 1) + label_counts = torch.where( + target >= 0, + torch.full(batch_shape, num_labels, dtype=torch.int32, device=target.device), + torch.zeros(batch_shape, dtype=torch.int32, device=target.device), + ) + + ref = reference_grpo_metrics( + logits, + target, + advantages, + old_log_probabilities, + label_counts, + epsilon_low=0.2, + epsilon_high=0.2, + logits_scale_factor=logits_scale_factor, + compute_entropy=compute_entropy, + ) + got = compute_grpo_metrics( + split_op(logits, group, -1).contiguous(), + target, + old_log_probabilities, + advantages, + label_counts, + epsilon_low=0.2, + epsilon_high=0.2, + logits_scale_factor=logits_scale_factor, + group=group, + compute_entropy=compute_entropy, + ) + threshold = 1e-5 if dtype == DataType.float32 else 1e-4 + for key, ref_value in ref.items(): + if ref_value is None: + assert getattr(got, key) is None + else: + Assert.rms_close_relative(getattr(got, key), ref_value, threshold, 1e-6) + + def _test_z_loss( batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate, group=None ): @@ -421,6 +506,27 @@ def test_grpo_loss( ) +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", 
"block_size", "accumulate"), + _LOSS_PARAMETERS, +) +@pytest.mark.parametrize("compute_entropy", (False, True)) +def test_grpo_metrics( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, + block_size, + accumulate, + compute_entropy, +): + _test_grpo_metrics(batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy) + + @pytest.mark.skip(reason="DPO loss is broken") def test_dpo_loss(): logits = torch.normal(0, 1, (200, 100)) @@ -498,6 +604,20 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa accumulate, test_context.group, ) + # GRPO metrics + for compute_entropy in (False, True): + with test_context.subtest(base_path, f"grpo_metrics-{compute_entropy}-{suffix}", 2) as subtest: + if subtest.do_run: + torch.manual_seed((seed + hash(subtest.name)) % 2**32) + _test_grpo_metrics( + batch_shape, + num_columns, + logits_scale_factor, + loss_masking, + dtype, + compute_entropy, + test_context.group, + ) @pytest.mark.slow @@ -538,6 +658,8 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): ), "z_loss", "grpo", + "grpo_metrics-False", + "grpo_metrics-True", ), ) def test_lm_loss_distributed( From bb6315cb8018b18004efc2514d2617ac3891b866 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 12:49:20 -0400 Subject: [PATCH 4/7] grpo: address review follow-ups - Drop stale "second softmax pass" overhead note from `metrics` description (entropy now reuses the base softmax outputs) - De-mirror max/min in reference_grpo_metrics: use advantages[loss_mask].max()/.min() instead of the implementation's -inf/+inf sentinel pattern Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 2 +- tests/layers/test_lm_losses.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 2c27d2e65..44180404c 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -217,7 +217,7 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc=( "Additional GRPO metrics to log. " "`basic`: old_logprobs, ratio, KL(new||old), advantage stats, clipped fraction, token count. " - "`with_entropy`: also log per-token entropy (-Σ p log p; ~10-20%% overhead from a second softmax pass). " + "`with_entropy`: also log per-token entropy (-Σ p log p). " "Not supported with pipeline_parallel > 1." ), hint=FieldHint.feature, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index e24b3236e..8b3df6aa3 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -151,8 +151,8 @@ def reference_grpo_metrics( "kl_new_old": (kl * masked).sum(), "clipped_ratio_fraction": (clipped.float() * masked).sum(), "advantage": (advantages.float() * masked).sum(), - "max_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("-inf"))).max(), - "min_advantage": torch.where(loss_mask, advantages, advantages.new_full((), float("inf"))).min(), + "max_advantage": advantages[loss_mask].max(), + "min_advantage": advantages[loss_mask].min(), "num_tokens": mask.sum(), "entropy": None, } From 89ed06241c302661e6f2114c11db4fbe82c3ac5d Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:30:21 -0400 Subject: [PATCH 5/7] grpo: round-3 review fixes - Align (logits, target, advantages, old_log_probabilities, ...) 
order across compute_grpo_metrics, fused_grpo_loss_forward_backward, and reference_grpo_metrics - Replace **kwargs in LanguageModelGRPOLoss.__init__ with the explicit keyword-only signature mirroring LanguageModelLoss.__init__ - num_docs -> num_documents - Drop the comment that restated the k3 KL formula - Give compute_grpo_metrics the same defaults as the loss kernel - Trim the metrics field description to category-level wording - Always exercise varying label_counts in _test_grpo_metrics so per-token denominator broadcasting is covered - reference_grpo_metrics returns GRPOMetrics; comparison loop iterates dataclasses.fields - Drop name = self._name micro-rebinds; use self._name inline - defs = super()...; defs.append(...); defs.extend(...) consistently - Tighten _register_extra_metrics losses type to dict[str, list[Tensor]] - Split compiled tuple-returning core from outer GRPOMetrics wrapper to avoid @torch.compile graph-breaks on dataclass construction - One-line comment on the metrics gate explaining the softmax-skip Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/config.py | 4 +- fast_llm/layers/language_model/loss/grpo.py | 150 ++++++++++++------ tests/layers/test_lm_losses.py | 69 ++++---- 3 files changed, 146 insertions(+), 77 deletions(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 44180404c..70cf8806a 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -216,8 +216,8 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): default=GRPOMetricsLevel.none, desc=( "Additional GRPO metrics to log. " - "`basic`: old_logprobs, ratio, KL(new||old), advantage stats, clipped fraction, token count. " - "`with_entropy`: also log per-token entropy (-Σ p log p). " + "`basic`: per-token ratio, KL, and advantage statistics. " + "`with_entropy`: also log per-token entropy. " "Not supported with pipeline_parallel > 1." ), hint=FieldHint.feature, diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 745f7abb6..8b2ec70c7 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -38,9 +38,28 @@ def __init__( self, config: ConfigType, distributed_config: DistributedConfig, - **kwargs, + *, + name: str, + prediction_distance: int = 1, + prediction_heads: int = 1, + vocab_parallel: bool = False, + num_splits: int = 1, + logits_scale_factor: float = 1.0, + weight: float = 1.0, + register_loss: bool = False, ): - super().__init__(config, distributed_config, **kwargs) + super().__init__( + config, + distributed_config, + name=name, + prediction_distance=prediction_distance, + prediction_heads=prediction_heads, + vocab_parallel=vocab_parallel, + num_splits=num_splits, + logits_scale_factor=logits_scale_factor, + weight=weight, + register_loss=register_loss, + ) # MAX/MIN reductions are unsafe under pipeline parallelism: ranks without this loss layer # contribute a torch.zeros([1]) placeholder in LossDef.reduce, which corrupts the extremum # whenever the real value has the opposite sign. @@ -89,6 +108,7 @@ def _forward_backward( self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM ) + # Skip the extra softmax pass when there is nothing to register. 
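+ # (compute_grpo_metrics runs its own fused_softmax_base pass over the full logits.)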
if losses is not None and self._config.metrics != GRPOMetricsLevel.none: self._register_extra_metrics(logits, kwargs, losses, split_index) @@ -98,14 +118,14 @@ def _register_extra_metrics( self, logits: torch.Tensor, kwargs: dict[str, typing.Any], - losses: dict, + losses: dict[str, list[torch.Tensor]], split_index: int, ) -> None: metrics = compute_grpo_metrics( logits, self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index), self._config.epsilon_low, self._config.epsilon_high, @@ -114,8 +134,7 @@ def _register_extra_metrics( compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy, ) - num_docs = kwargs[LanguageModelKwargs.num_documents_in_batch] - name = self._name + num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch] for attr in ( "old_logprobs", @@ -124,49 +143,51 @@ def _register_extra_metrics( "clipped_ratio_fraction", "advantage", ): - self._register_loss(f"{name}_{attr}", getattr(metrics, attr) / num_docs, losses) + self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr) / num_documents, losses) for attr in ( "ratio_new_old_sum", "ratio_new_old_squared_sum", "num_tokens", ): - self._register_loss(f"{name}_{attr}", getattr(metrics, attr), losses) + self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr), losses) self._register_loss( - f"{name}_max_advantage", + f"{self._name}_max_advantage", metrics.max_advantage, losses, reduce_op=torch.distributed.ReduceOp.MAX, ) self._register_loss( - f"{name}_min_advantage", + f"{self._name}_min_advantage", metrics.min_advantage, losses, reduce_op=torch.distributed.ReduceOp.MIN, ) if metrics.entropy is not None: - self._register_loss(f"{name}_entropy", metrics.entropy / num_docs, losses) + self._register_loss(f"{self._name}_entropy", metrics.entropy / num_documents, losses) def get_loss_definitions(self) -> list[LossDef]: - defs = super().get_loss_definitions() + [LossDef(self._logprob_metric_name)] + defs = super().get_loss_definitions() + defs.append(LossDef(self._logprob_metric_name)) if self._config.metrics != GRPOMetricsLevel.none: - name = self._name - defs += [ - LossDef(f"{name}_old_logprobs"), - LossDef(f"{name}_ratio_new_old"), - LossDef(f"{name}_ratio_new_old_sum"), - LossDef(f"{name}_ratio_new_old_squared_sum"), - LossDef(f"{name}_kl_new_old"), - LossDef(f"{name}_clipped_ratio_fraction"), - LossDef(f"{name}_advantage"), - LossDef(f"{name}_max_advantage", reduction=ReductionType.maximum), - LossDef(f"{name}_min_advantage", reduction=ReductionType.minimum), - LossDef(f"{name}_num_tokens"), - ] + defs.extend( + [ + LossDef(f"{self._name}_old_logprobs"), + LossDef(f"{self._name}_ratio_new_old"), + LossDef(f"{self._name}_ratio_new_old_sum"), + LossDef(f"{self._name}_ratio_new_old_squared_sum"), + LossDef(f"{self._name}_kl_new_old"), + LossDef(f"{self._name}_clipped_ratio_fraction"), + LossDef(f"{self._name}_advantage"), + LossDef(f"{self._name}_max_advantage", reduction=ReductionType.maximum), + LossDef(f"{self._name}_min_advantage", reduction=ReductionType.minimum), + LossDef(f"{self._name}_num_tokens"), + ] + ) if self._config.metrics == GRPOMetricsLevel.with_entropy: - defs.append(LossDef(f"{name}_entropy")) + defs.append(LossDef(f"{self._name}_entropy")) return 
defs def get_preprocessing_config( @@ -179,23 +200,62 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" -@torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) target: torch.Tensor, # (*batch,) - old_log_probabilities: torch.Tensor, # (*batch,) advantages: torch.Tensor, # (*batch,) + old_log_probabilities: torch.Tensor, # (*batch,) label_counts: torch.Tensor, # (*batch,) — global per-sequence count broadcast per token + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + group: torch.distributed.ProcessGroup | None = None, + compute_entropy: bool = False, +) -> GRPOMetrics: + return GRPOMetrics( + *_compute_grpo_metrics( + logits, + target, + advantages, + old_log_probabilities, + label_counts, + epsilon_low, + epsilon_high, + logits_scale_factor, + group, + compute_entropy, + ) + ) + + +@torch.compile +def _compute_grpo_metrics( + logits: torch.Tensor, + target: torch.Tensor, + advantages: torch.Tensor, + old_log_probabilities: torch.Tensor, + label_counts: torch.Tensor, epsilon_low: float, epsilon_high: float, logits_scale_factor: float, group: torch.distributed.ProcessGroup | None, compute_entropy: bool, -) -> GRPOMetrics: +) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor, + torch.Tensor | None, +]: loss_mask = target >= 0 mask = loss_mask.float() - denom = label_counts.float().clamp(min=1) - masked = mask / denom + masked = mask / label_counts.float().clamp(min=1) logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group) @@ -204,29 +264,29 @@ def compute_grpo_metrics( log_ratio = new_log_probs - old_log_probabilities ratio = log_ratio.exp() clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) - # Schulman k3 KL approximation: exp(r) - r - 1 + # k3 kl = ratio - log_ratio - 1.0 neg_inf = advantages.new_full((), float("-inf")) pos_inf = advantages.new_full((), float("inf")) - entropy = None + entropy: torch.Tensor | None = None if compute_entropy: entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits entropy = (entropy_per_token * masked).sum() - return GRPOMetrics( - old_logprobs=(old_log_probabilities * masked).sum(), - ratio_new_old=(ratio * masked).sum(), - ratio_new_old_sum=(ratio * mask).sum(), - ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), - kl_new_old=(kl * masked).sum(), - clipped_ratio_fraction=(clipped.float() * masked).sum(), - advantage=(advantages * masked).sum(), - max_advantage=torch.where(loss_mask, advantages, neg_inf).max(), - min_advantage=torch.where(loss_mask, advantages, pos_inf).min(), - num_tokens=mask.sum(), - entropy=entropy, + return ( + (old_log_probabilities * masked).sum(), + (ratio * masked).sum(), + (ratio * mask).sum(), + (ratio * ratio * mask).sum(), + (kl * masked).sum(), + (clipped.float() * masked).sum(), + (advantages * masked).sum(), + torch.where(loss_mask, advantages, neg_inf).max(), + torch.where(loss_mask, advantages, pos_inf).min(), + mask.sum(), + entropy, ) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 8b3df6aa3..79b9e5f79 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -1,3 +1,4 @@ +import dataclasses import pathlib import random @@ -16,7 +17,11 
@@ from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss -from fast_llm.layers.language_model.loss.grpo import compute_grpo_metrics, fused_grpo_loss_forward_backward +from fast_llm.layers.language_model.loss.grpo import ( + GRPOMetrics, + compute_grpo_metrics, + fused_grpo_loss_forward_backward, +) from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss from fast_llm.utils import Assert @@ -131,7 +136,7 @@ def reference_grpo_metrics( epsilon_high: float, logits_scale_factor: float, compute_entropy: bool, -) -> dict[str, torch.Tensor]: +) -> GRPOMetrics: log_softmax = torch.nn.functional.log_softmax(logits.float() * logits_scale_factor, dim=-1) loss_mask = target >= 0 mask = loss_mask.float() @@ -143,23 +148,24 @@ def reference_grpo_metrics( clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) kl = ratio - log_ratio - 1.0 - metrics = { - "old_logprobs": (old_log_probabilities.float() * masked).sum(), - "ratio_new_old": (ratio * masked).sum(), - "ratio_new_old_sum": (ratio * mask).sum(), - "ratio_new_old_squared_sum": (ratio * ratio * mask).sum(), - "kl_new_old": (kl * masked).sum(), - "clipped_ratio_fraction": (clipped.float() * masked).sum(), - "advantage": (advantages.float() * masked).sum(), - "max_advantage": advantages[loss_mask].max(), - "min_advantage": advantages[loss_mask].min(), - "num_tokens": mask.sum(), - "entropy": None, - } + entropy = None if compute_entropy: entropy_per_token = -(log_softmax.exp() * log_softmax).sum(-1) - metrics["entropy"] = (entropy_per_token * masked).sum() - return metrics + entropy = (entropy_per_token * masked).sum() + + return GRPOMetrics( + old_logprobs=(old_log_probabilities.float() * masked).sum(), + ratio_new_old=(ratio * masked).sum(), + ratio_new_old_sum=(ratio * mask).sum(), + ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), + kl_new_old=(kl * masked).sum(), + clipped_ratio_fraction=(clipped.float() * masked).sum(), + advantage=(advantages.float() * masked).sum(), + max_advantage=advantages[loss_mask].max(), + min_advantage=advantages[loss_mask].min(), + num_tokens=mask.sum(), + entropy=entropy, + ) def reference_grpo_loss( @@ -345,18 +351,26 @@ def _test_grpo_loss( Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6) +def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None: + for field in dataclasses.fields(GRPOMetrics): + ref_value = getattr(ref, field.name) + got_value = getattr(got, field.name) + if ref_value is None: + assert got_value is None, field.name + else: + Assert.rms_close_relative(got_value, ref_value, threshold, 1e-6) + + def _test_grpo_metrics( batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy, group=None ): logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs( num_columns, loss_masking, batch_shape, dtype ) - num_labels = max(int((target >= 0).sum().item()), 1) - label_counts = torch.where( - target >= 0, - torch.full(batch_shape, num_labels, dtype=torch.int32, device=target.device), - torch.zeros(batch_shape, dtype=torch.int32, device=target.device), - ) + # Different denominators per position so the per-token-mean broadcasting is exercised. 
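+ # e.g. 6 tokens with loss mask [T, T, F, T, T, T]:
+ # (arange % 5 + 1) = [1, 2, 3, 4, 5, 1] -> label_counts = [1, 2, 0, 4, 5, 1]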
+ label_counts = (torch.arange(target.numel(), device=target.device).reshape(target.shape) % 5 + 1).to( + torch.int32 + ) * (target >= 0) ref = reference_grpo_metrics( logits, @@ -372,8 +386,8 @@ def _test_grpo_metrics( got = compute_grpo_metrics( split_op(logits, group, -1).contiguous(), target, - old_log_probabilities, advantages, + old_log_probabilities, label_counts, epsilon_low=0.2, epsilon_high=0.2, @@ -381,12 +395,7 @@ def _test_grpo_metrics( group=group, compute_entropy=compute_entropy, ) - threshold = 1e-5 if dtype == DataType.float32 else 1e-4 - for key, ref_value in ref.items(): - if ref_value is None: - assert getattr(got, key) is None - else: - Assert.rms_close_relative(getattr(got, key), ref_value, threshold, 1e-6) + _check_grpo_metrics(ref, got, threshold=1e-5 if dtype == DataType.float32 else 1e-4) def _test_z_loss( From b0852fdbb97b26c2dad53f7ef8567aa57f309235 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 13:44:18 -0400 Subject: [PATCH 6/7] grpo: GRPOMetrics as NamedTuple MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit NamedTuple is a tuple subclass that dynamo handles natively, so the previous wrapper/inner split (added to dodge a dataclass graph-break) collapses into one @torch.compile function. Field order now lives exactly once — on the class. Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/grpo.py | 70 +++++---------------- tests/layers/test_lm_losses.py | 9 ++- 2 files changed, 18 insertions(+), 61 deletions(-) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 8b2ec70c7..dc134c652 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -1,4 +1,3 @@ -import dataclasses import functools import typing @@ -18,8 +17,7 @@ from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -@dataclasses.dataclass -class GRPOMetrics: +class GRPOMetrics(typing.NamedTuple): old_logprobs: torch.Tensor ratio_new_old: torch.Tensor ratio_new_old_sum: torch.Tensor @@ -200,6 +198,7 @@ def _logprob_metric_name(self) -> str: return f"{self._name}_new_logprobs" +@torch.compile def compute_grpo_metrics( logits: torch.Tensor, # (*batch, vocab_local) target: torch.Tensor, # (*batch,) @@ -212,47 +211,6 @@ def compute_grpo_metrics( group: torch.distributed.ProcessGroup | None = None, compute_entropy: bool = False, ) -> GRPOMetrics: - return GRPOMetrics( - *_compute_grpo_metrics( - logits, - target, - advantages, - old_log_probabilities, - label_counts, - epsilon_low, - epsilon_high, - logits_scale_factor, - group, - compute_entropy, - ) - ) - - -@torch.compile -def _compute_grpo_metrics( - logits: torch.Tensor, - target: torch.Tensor, - advantages: torch.Tensor, - old_log_probabilities: torch.Tensor, - label_counts: torch.Tensor, - epsilon_low: float, - epsilon_high: float, - logits_scale_factor: float, - group: torch.distributed.ProcessGroup | None, - compute_entropy: bool, -) -> tuple[ - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor, - torch.Tensor | None, -]: loss_mask = target >= 0 mask = loss_mask.float() masked = mask / label_counts.float().clamp(min=1) @@ -275,18 +233,18 @@ def _compute_grpo_metrics( entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits entropy = (entropy_per_token * masked).sum() - return ( - 
(old_log_probabilities * masked).sum(), - (ratio * masked).sum(), - (ratio * mask).sum(), - (ratio * ratio * mask).sum(), - (kl * masked).sum(), - (clipped.float() * masked).sum(), - (advantages * masked).sum(), - torch.where(loss_mask, advantages, neg_inf).max(), - torch.where(loss_mask, advantages, pos_inf).min(), - mask.sum(), - entropy, + return GRPOMetrics( + old_logprobs=(old_log_probabilities * masked).sum(), + ratio_new_old=(ratio * masked).sum(), + ratio_new_old_sum=(ratio * mask).sum(), + ratio_new_old_squared_sum=(ratio * ratio * mask).sum(), + kl_new_old=(kl * masked).sum(), + clipped_ratio_fraction=(clipped.float() * masked).sum(), + advantage=(advantages * masked).sum(), + max_advantage=torch.where(loss_mask, advantages, neg_inf).max(), + min_advantage=torch.where(loss_mask, advantages, pos_inf).min(), + num_tokens=mask.sum(), + entropy=entropy, ) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 79b9e5f79..19200476a 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -1,4 +1,3 @@ -import dataclasses import pathlib import random @@ -352,11 +351,11 @@ def _test_grpo_loss( def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None: - for field in dataclasses.fields(GRPOMetrics): - ref_value = getattr(ref, field.name) - got_value = getattr(got, field.name) + for name in GRPOMetrics._fields: + ref_value = getattr(ref, name) + got_value = getattr(got, name) if ref_value is None: - assert got_value is None, field.name + assert got_value is None, name else: Assert.rms_close_relative(got_value, ref_value, threshold, 1e-6) From 61ad4f75b8905b9c9b597cf8c5bb89ae01a41d05 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 5 May 2026 14:12:03 -0400 Subject: [PATCH 7/7] grpo: fix entropy under tensor-parallel + minor review fixes - Entropy under vocab-parallel TP was wrong: the dot-product term (exp_logits * logits_norm).sum(-1) summed only the local vocab slice, so dividing by the global sum_exp_logits gave a per-rank fragment instead of the full E_p[logit_norm]. All-reduce the partial sum. - Replace the verbose pipeline-parallel guard with Assert.custom; the field description already explains the constraint. - Drop the cryptic `# k3` comment. - Match _register_extra_metrics losses annotation to the base class (dict | None). Co-Authored-By: Claude Opus 4.7 (1M context) --- fast_llm/layers/language_model/loss/grpo.py | 26 ++++++++++++--------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index dc134c652..4bbaeb581 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -3,6 +3,7 @@ import torch +from fast_llm.core.distributed import ReduceOp, all_reduce from fast_llm.engine.base_model.config import LossDef, ReductionType from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.functional.config import TritonConfig @@ -15,6 +16,7 @@ LanguageModelLossKwargs, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.utils import Assert class GRPOMetrics(typing.NamedTuple): @@ -58,14 +60,11 @@ def __init__( weight=weight, register_loss=register_loss, ) - # MAX/MIN reductions are unsafe under pipeline parallelism: ranks without this loss layer - # contribute a torch.zeros([1]) placeholder in LossDef.reduce, which corrupts the extremum - # whenever the real value has the opposite sign. 
- if config.metrics != GRPOMetricsLevel.none and distributed_config.pipeline_parallel > 1: - raise NotImplementedError( - "GRPO extra metrics are not supported with pipeline_parallel > 1 " - "(MAX/MIN advantage reductions would be corrupted by the zero placeholder on empty pipeline ranks)." - ) + Assert.custom( + lambda metrics, pipeline_parallel: metrics == GRPOMetricsLevel.none or pipeline_parallel == 1, + config.metrics, + distributed_config.pipeline_parallel, + ) def _forward_backward( self, @@ -116,7 +115,7 @@ def _register_extra_metrics( self, logits: torch.Tensor, kwargs: dict[str, typing.Any], - losses: dict[str, list[torch.Tensor]], + losses: dict | None, split_index: int, ) -> None: metrics = compute_grpo_metrics( @@ -222,7 +221,6 @@ def compute_grpo_metrics( log_ratio = new_log_probs - old_log_probabilities ratio = log_ratio.exp() clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high) - # k3 kl = ratio - log_ratio - 1.0 neg_inf = advantages.new_full((), float("-inf")) @@ -230,7 +228,13 @@ def compute_grpo_metrics( entropy: torch.Tensor | None = None if compute_entropy: - entropy_per_token = sum_exp_logits.log() - (exp_logits * logits_norm).sum(-1) / sum_exp_logits + # exp_logits and logits_norm are local vocab slices — sum over the local slice, then all-reduce + # across the tensor-parallel group to recover the global E_p[logit_norm] before dividing by the + # already-global sum_exp_logits. + weighted_logits_sum = (exp_logits * logits_norm).sum(-1) + if group is not None: + all_reduce(weighted_logits_sum, op=ReduceOp.SUM, group=group) + entropy_per_token = sum_exp_logits.log() - weighted_logits_sum / sum_exp_logits entropy = (entropy_per_token * masked).sum() return GRPOMetrics(