diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py
index 4381aa5d9..70cf8806a 100644
--- a/fast_llm/layers/language_model/loss/config.py
+++ b/fast_llm/layers/language_model/loss/config.py
@@ -1,3 +1,4 @@
+import enum
 import typing
 import warnings
 
@@ -193,6 +194,12 @@ def loss_class(self) -> "type[LanguageModelZLoss]":
         return LanguageModelZLoss
 
 
+class GRPOMetricsLevel(enum.StrEnum):
+    none = "none"
+    basic = "basic"
+    with_entropy = "with_entropy"
+
+
 @config_class(dynamic_type={LanguageModelLossConfig: "grpo"})
 class LanguageModelGRPOLossConfig(LanguageModelLossConfig):
@@ -205,6 +212,16 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig):
         desc="Enable triton implementation. Default: use if available.",
         hint=FieldHint.expert,
     )
+    metrics: GRPOMetricsLevel = Field(
+        default=GRPOMetricsLevel.none,
+        desc=(
+            "Additional GRPO metrics to log. "
+            "`basic`: per-token ratio, KL, and advantage statistics. "
+            "`with_entropy`: also log per-token entropy. "
+            "Not supported with pipeline_parallel > 1."
+        ),
+        hint=FieldHint.feature,
+    )
 
     @property
     def loss_class(self) -> "type[LanguageModelGRPOLoss]":
diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index cc6cbf726..4bbaeb581 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -3,16 +3,69 @@
 import torch
 
-from fast_llm.engine.base_model.config import LossDef
+from fast_llm.core.distributed import ReduceOp, all_reduce
+from fast_llm.engine.base_model.config import LossDef, ReductionType
+from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.functional.config import TritonConfig
 from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
 from fast_llm.functional.utils import reduce_losses
 from fast_llm.layers.language_model.config import LanguageModelKwargs
-from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
+from fast_llm.layers.language_model.loss.config import (
+    GRPOMetricsLevel,
+    LanguageModelGRPOLossConfig,
+    LanguageModelLossKwargs,
+)
 from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
+from fast_llm.utils import Assert
+
+
+class GRPOMetrics(typing.NamedTuple):
+    old_logprobs: torch.Tensor
+    ratio_new_old: torch.Tensor
+    ratio_new_old_sum: torch.Tensor
+    ratio_new_old_squared_sum: torch.Tensor
+    kl_new_old: torch.Tensor
+    clipped_ratio_fraction: torch.Tensor
+    advantage: torch.Tensor
+    max_advantage: torch.Tensor
+    min_advantage: torch.Tensor
+    num_tokens: torch.Tensor
+    entropy: torch.Tensor | None
 
 
 class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
+    def __init__(
+        self,
+        config: ConfigType,
+        distributed_config: DistributedConfig,
+        *,
+        name: str,
+        prediction_distance: int = 1,
+        prediction_heads: int = 1,
+        vocab_parallel: bool = False,
+        num_splits: int = 1,
+        logits_scale_factor: float = 1.0,
+        weight: float = 1.0,
+        register_loss: bool = False,
+    ):
+        super().__init__(
+            config,
+            distributed_config,
+            name=name,
+            prediction_distance=prediction_distance,
+            prediction_heads=prediction_heads,
+            vocab_parallel=vocab_parallel,
+            num_splits=num_splits,
+            logits_scale_factor=logits_scale_factor,
+            weight=weight,
+            register_loss=register_loss,
+        )
+        Assert.custom(
+            lambda metrics, pipeline_parallel: metrics == GRPOMetricsLevel.none or pipeline_parallel == 1,
+            config.metrics,
+            distributed_config.pipeline_parallel,
+        )
+
     def _forward_backward(
         self,
         logits: "torch.Tensor",
@@ -51,10 +104,88 @@ def _forward_backward(
         self._register_loss(
             self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM
         )
+
+        # Skip the extra softmax pass when there is nothing to register.
+        if losses is not None and self._config.metrics != GRPOMetricsLevel.none:
+            self._register_extra_metrics(logits, kwargs, losses, split_index)
+
         return loss, grad
 
+    def _register_extra_metrics(
+        self,
+        logits: torch.Tensor,
+        kwargs: dict[str, typing.Any],
+        losses: dict | None,
+        split_index: int,
+    ) -> None:
+        metrics = compute_grpo_metrics(
+            logits,
+            self._get_labels(kwargs, split_index),
+            self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index),
+            self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index),
+            self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index),
+            self._config.epsilon_low,
+            self._config.epsilon_high,
+            self._logits_scale_factor,
+            group=self._parallel_dim.group if self._vocab_parallel else None,
+            compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy,
+        )
+
+        num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch]
+
+        for attr in (
+            "old_logprobs",
+            "ratio_new_old",
+            "kl_new_old",
+            "clipped_ratio_fraction",
+            "advantage",
+        ):
+            self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr) / num_documents, losses)
+
+        for attr in (
+            "ratio_new_old_sum",
+            "ratio_new_old_squared_sum",
+            "num_tokens",
+        ):
+            self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr), losses)
+
+        self._register_loss(
+            f"{self._name}_max_advantage",
+            metrics.max_advantage,
+            losses,
+            reduce_op=torch.distributed.ReduceOp.MAX,
+        )
+        self._register_loss(
+            f"{self._name}_min_advantage",
+            metrics.min_advantage,
+            losses,
+            reduce_op=torch.distributed.ReduceOp.MIN,
+        )
+
+        if metrics.entropy is not None:
+            self._register_loss(f"{self._name}_entropy", metrics.entropy / num_documents, losses)
+
     def get_loss_definitions(self) -> list[LossDef]:
-        return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)]
+        defs = super().get_loss_definitions()
+        defs.append(LossDef(self._logprob_metric_name))
+        if self._config.metrics != GRPOMetricsLevel.none:
+            defs.extend(
+                [
+                    LossDef(f"{self._name}_old_logprobs"),
+                    LossDef(f"{self._name}_ratio_new_old"),
+                    LossDef(f"{self._name}_ratio_new_old_sum"),
+                    LossDef(f"{self._name}_ratio_new_old_squared_sum"),
+                    LossDef(f"{self._name}_kl_new_old"),
+                    LossDef(f"{self._name}_clipped_ratio_fraction"),
+                    LossDef(f"{self._name}_advantage"),
+                    LossDef(f"{self._name}_max_advantage", reduction=ReductionType.maximum),
+                    LossDef(f"{self._name}_min_advantage", reduction=ReductionType.minimum),
+                    LossDef(f"{self._name}_num_tokens"),
+                ]
+            )
+        if self._config.metrics == GRPOMetricsLevel.with_entropy:
+            defs.append(LossDef(f"{self._name}_entropy"))
+        return defs
 
     def get_preprocessing_config(
         self,
@@ -66,6 +197,61 @@ def _logprob_metric_name(self) -> str:
         return f"{self._name}_new_logprobs"
 
 
+@torch.compile
+def compute_grpo_metrics(
+    logits: torch.Tensor,  # (*batch, vocab_local)
+    target: torch.Tensor,  # (*batch,)
+    advantages: torch.Tensor,  # (*batch,)
+    old_log_probabilities: torch.Tensor,  # (*batch,)
+    label_counts: torch.Tensor,  # (*batch,) - global per-sequence count broadcast per token
+    epsilon_low: float = 0.2,
+    epsilon_high: float = 0.2,
+    logits_scale_factor: float = 1.0,
+    group: torch.distributed.ProcessGroup | None = None,
+    compute_entropy: bool = False,
+) -> GRPOMetrics:
+    loss_mask = target >= 0
+    mask = loss_mask.float()
+    masked = mask / label_counts.float().clamp(min=1)
+
+    logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group)
+    predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group)
+    new_log_probs = predicted_logits - sum_exp_logits.log()
+
+    log_ratio = new_log_probs - old_log_probabilities
+    ratio = log_ratio.exp()
+    clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high)
+    kl = ratio - log_ratio - 1.0
+
+    neg_inf = advantages.new_full((), float("-inf"))
+    pos_inf = advantages.new_full((), float("inf"))
+
+    entropy: torch.Tensor | None = None
+    if compute_entropy:
+        # exp_logits and logits_norm are local vocab slices: sum over the local slice, then all-reduce
+        # across the tensor-parallel group to recover the global E_p[logit_norm] before dividing by the
+        # already-global sum_exp_logits.
+        weighted_logits_sum = (exp_logits * logits_norm).sum(-1)
+        if group is not None:
+            all_reduce(weighted_logits_sum, op=ReduceOp.SUM, group=group)
+        entropy_per_token = sum_exp_logits.log() - weighted_logits_sum / sum_exp_logits
+        entropy = (entropy_per_token * masked).sum()
+
+    return GRPOMetrics(
+        old_logprobs=(old_log_probabilities * masked).sum(),
+        ratio_new_old=(ratio * masked).sum(),
+        ratio_new_old_sum=(ratio * mask).sum(),
+        ratio_new_old_squared_sum=(ratio * ratio * mask).sum(),
+        kl_new_old=(kl * masked).sum(),
+        clipped_ratio_fraction=(clipped.float() * masked).sum(),
+        advantage=(advantages * masked).sum(),
+        max_advantage=torch.where(loss_mask, advantages, neg_inf).max(),
+        min_advantage=torch.where(loss_mask, advantages, pos_inf).min(),
+        num_tokens=mask.sum(),
+        entropy=entropy,
+    )
+
+
 @torch.compile
 def fused_grpo_loss_forward_backward(
     logits: torch.Tensor,  # (*batch, vocab)
diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py
index 9b93aeb66..19200476a 100644
--- a/tests/layers/test_lm_losses.py
+++ b/tests/layers/test_lm_losses.py
@@ -16,7 +16,11 @@
 from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward
 from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward
 from fast_llm.layers.language_model.loss.dpo import dpo_loss
-from fast_llm.layers.language_model.loss.grpo import fused_grpo_loss_forward_backward
+from fast_llm.layers.language_model.loss.grpo import (
+    GRPOMetrics,
+    compute_grpo_metrics,
+    fused_grpo_loss_forward_backward,
+)
 from fast_llm.layers.language_model.loss.loss import loss_forward_backward
 from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss
 from fast_llm.utils import Assert
@@ -121,6 +125,48 @@ def reference_dpo_loss(
     return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()
 
 
+def reference_grpo_metrics(
+    logits: torch.Tensor,
+    target: torch.Tensor,
+    advantages: torch.Tensor,
+    old_log_probabilities: torch.Tensor,
+    label_counts: torch.Tensor,
+    epsilon_low: float,
+    epsilon_high: float,
+    logits_scale_factor: float,
+    compute_entropy: bool,
+) -> GRPOMetrics:
+    log_softmax = torch.nn.functional.log_softmax(logits.float() * logits_scale_factor, dim=-1)
+    loss_mask = target >= 0
+    mask = loss_mask.float()
+    masked = mask / label_counts.float().clamp(min=1)
+
+    new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1)
+    log_ratio = new_log_probs - old_log_probabilities.float()
+    ratio = log_ratio.exp()
+    clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high)
+    kl = ratio - log_ratio - 1.0
+
+    entropy = None
+    if compute_entropy:
+        entropy_per_token = -(log_softmax.exp() * log_softmax).sum(-1)
+        entropy = (entropy_per_token * masked).sum()
+
+    return GRPOMetrics(
+        old_logprobs=(old_log_probabilities.float() * masked).sum(),
+        ratio_new_old=(ratio * masked).sum(),
+        ratio_new_old_sum=(ratio * mask).sum(),
+        ratio_new_old_squared_sum=(ratio * ratio * mask).sum(),
+        kl_new_old=(kl * masked).sum(),
+        clipped_ratio_fraction=(clipped.float() * masked).sum(),
+        advantage=(advantages.float() * masked).sum(),
+        max_advantage=advantages[loss_mask].max(),
+        min_advantage=advantages[loss_mask].min(),
+        num_tokens=mask.sum(),
+        entropy=entropy,
+    )
+
+
 def reference_grpo_loss(
     logits: torch.Tensor,
     labels: torch.Tensor,
@@ -304,6 +350,53 @@ def _test_grpo_loss(
     Assert.rms_close_relative(new_logprobs_triton, new_logprobs_fused, 1e-5, 1e-6)
 
 
+def _check_grpo_metrics(ref: GRPOMetrics, got: GRPOMetrics, threshold: float) -> None:
+    for name in GRPOMetrics._fields:
+        ref_value = getattr(ref, name)
+        got_value = getattr(got, name)
+        if ref_value is None:
+            assert got_value is None, name
+        else:
+            Assert.rms_close_relative(got_value, ref_value, threshold, 1e-6)
+
+
+def _test_grpo_metrics(
+    batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy, group=None
+):
+    logits, target, advantages, old_log_probabilities = _get_grpo_loss_inputs(
+        num_columns, loss_masking, batch_shape, dtype
+    )
+    # Different denominators per position so the per-token-mean broadcasting is exercised.
+    label_counts = (torch.arange(target.numel(), device=target.device).reshape(target.shape) % 5 + 1).to(
+        torch.int32
+    ) * (target >= 0)
+
+    ref = reference_grpo_metrics(
+        logits,
+        target,
+        advantages,
+        old_log_probabilities,
+        label_counts,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        logits_scale_factor=logits_scale_factor,
+        compute_entropy=compute_entropy,
+    )
+    got = compute_grpo_metrics(
+        split_op(logits, group, -1).contiguous(),
+        target,
+        advantages,
+        old_log_probabilities,
+        label_counts,
+        epsilon_low=0.2,
+        epsilon_high=0.2,
+        logits_scale_factor=logits_scale_factor,
+        group=group,
+        compute_entropy=compute_entropy,
+    )
+    _check_grpo_metrics(ref, got, threshold=1e-5 if dtype == DataType.float32 else 1e-4)
+
+
 def _test_z_loss(
     batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate, group=None
 ):
@@ -421,6 +514,27 @@ def test_grpo_loss(
     )
 
 
+@pytest.mark.slow
+@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES)
+@pytest.mark.parametrize(
+    ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size", "accumulate"),
+    _LOSS_PARAMETERS,
+)
+@pytest.mark.parametrize("compute_entropy", (False, True))
+def test_grpo_metrics(
+    batch_shape,
+    num_columns,
+    grad_output,
+    logits_scale_factor,
+    loss_masking,
+    dtype,
+    block_size,
+    accumulate,
+    compute_entropy,
+):
+    _test_grpo_metrics(batch_shape, num_columns, logits_scale_factor, loss_masking, dtype, compute_entropy)
+
+
 @pytest.mark.skip(reason="DPO loss is broken")
 def test_dpo_loss():
     logits = torch.normal(0, 1, (200, 100))
@@ -498,6 +612,20 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa
                     accumulate,
                     test_context.group,
                 )
+            # GRPO metrics
+            for compute_entropy in (False, True):
+                with test_context.subtest(base_path, f"grpo_metrics-{compute_entropy}-{suffix}", 2) as subtest:
+                    if subtest.do_run:
+                        torch.manual_seed((seed + hash(subtest.name)) % 2**32)
+                        _test_grpo_metrics(
+                            batch_shape,
+                            num_columns,
+                            logits_scale_factor,
+                            loss_masking,
+                            dtype,
+                            compute_entropy,
+                            test_context.group,
+                        )
 
 
 @pytest.mark.slow
@@ -538,6 +666,8 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path):
         ),
         "z_loss",
         "grpo",
+        "grpo_metrics-False",
+        "grpo_metrics-True",
     ),
 )
 def test_lm_loss_distributed(
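
For reference, a minimal single-device sketch of how the new compute_grpo_metrics helper can be exercised directly (no tensor-parallel group). The tensor shapes, the -100 masking label, and the way label_counts is built here are illustrative assumptions, not part of this diff; compute_entropy=True corresponds to the with_entropy metrics level in the config.

    import torch

    from fast_llm.layers.language_model.loss.grpo import compute_grpo_metrics

    # Toy shapes, chosen only for illustration.
    batch, seq, vocab = 2, 4, 8
    logits = torch.randn(batch, seq, vocab)
    target = torch.randint(0, vocab, (batch, seq))
    target[0, 0] = -100  # negative labels are masked out (loss_mask = target >= 0)
    advantages = torch.randn(batch, seq)
    # Stand-in for the sampling policy's log-probabilities of the chosen tokens.
    old_log_probabilities = torch.randn(batch, seq).abs().neg()
    # Per-token copy of the number of labeled tokens in each sequence.
    label_counts = (target >= 0).sum(-1, keepdim=True).expand_as(target)

    metrics = compute_grpo_metrics(
        logits,
        target,
        advantages,
        old_log_probabilities,
        label_counts,
        compute_entropy=True,  # maps to the `with_entropy` config level
    )
    print(metrics.ratio_new_old, metrics.kl_new_old, metrics.entropy)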