17 changes: 17 additions & 0 deletions fast_llm/layers/language_model/loss/config.py
@@ -1,3 +1,4 @@
import enum
import typing
import warnings

@@ -193,6 +194,12 @@ def loss_class(self) -> "type[LanguageModelZLoss]":
return LanguageModelZLoss


class GRPOMetricsLevel(enum.StrEnum):
none = "none"
basic = "basic"
with_entropy = "with_entropy"


@config_class(dynamic_type={LanguageModelLossConfig: "grpo"})
class LanguageModelGRPOLossConfig(LanguageModelLossConfig):

@@ -205,6 +212,16 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig):
desc="Enable triton implementation. Default: use if available.",
hint=FieldHint.expert,
)
metrics: GRPOMetricsLevel = Field(
default=GRPOMetricsLevel.none,
desc=(
"Additional GRPO metrics to log. "
"`none`: disable the extra metrics (default). "
"`basic`: per-token ratio, KL, and advantage statistics. "
"`with_entropy`: also log per-token entropy. "
"Not supported with pipeline_parallel > 1."
),
hint=FieldHint.feature,
)

@property
def loss_class(self) -> "type[LanguageModelGRPOLoss]":
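A hedged usage sketch for the new field (the `from_dict` entry point and the exact key names are assumptions based on the `dynamic_type` registration above, not something this diff confirms):

```python
from fast_llm.layers.language_model.loss.config import GRPOMetricsLevel, LanguageModelGRPOLossConfig

# Assumed construction path: dynamic_type={LanguageModelLossConfig: "grpo"} suggests
# dict-based configs where "type" selects the loss class.
cfg = LanguageModelGRPOLossConfig.from_dict({"type": "grpo", "metrics": "with_entropy"})
assert cfg.metrics is GRPOMetricsLevel.with_entropy  # StrEnum members parse from their string values
```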
192 changes: 189 additions & 3 deletions fast_llm/layers/language_model/loss/grpo.py
@@ -3,16 +3,69 @@

import torch

-from fast_llm.engine.base_model.config import LossDef
from fast_llm.core.distributed import ReduceOp, all_reduce
from fast_llm.engine.base_model.config import LossDef, ReductionType
from fast_llm.engine.distributed.config import DistributedConfig
from fast_llm.functional.config import TritonConfig
from fast_llm.functional.entropy_loss import fused_predicted_logits_from_labels, fused_softmax_base
from fast_llm.functional.utils import reduce_losses
from fast_llm.layers.language_model.config import LanguageModelKwargs
-from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs
from fast_llm.layers.language_model.loss.config import (
GRPOMetricsLevel,
LanguageModelGRPOLossConfig,
LanguageModelLossKwargs,
)
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.utils import Assert


class GRPOMetrics(typing.NamedTuple):
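# All fields are scalar tensors accumulated over the local tokens. Fields without a
# `_sum` suffix are weighted by 1 / label_counts (each sequence contributes its token
# mean); the `*_sum` fields and `num_tokens` are raw masked sums; the advantage
# extrema are masked max / min.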
old_logprobs: torch.Tensor
ratio_new_old: torch.Tensor
ratio_new_old_sum: torch.Tensor
ratio_new_old_squared_sum: torch.Tensor
kl_new_old: torch.Tensor
clipped_ratio_fraction: torch.Tensor
advantage: torch.Tensor
max_advantage: torch.Tensor
min_advantage: torch.Tensor
num_tokens: torch.Tensor
entropy: torch.Tensor | None


class LanguageModelGRPOLoss[ConfigType: LanguageModelGRPOLossConfig](LanguageModelLoss[ConfigType]):
def __init__(
self,
config: ConfigType,
distributed_config: DistributedConfig,
*,
name: str,
prediction_distance: int = 1,
prediction_heads: int = 1,
vocab_parallel: bool = False,
num_splits: int = 1,
logits_scale_factor: float = 1.0,
weight: float = 1.0,
register_loss: bool = False,
):
super().__init__(
config,
distributed_config,
name=name,
prediction_distance=prediction_distance,
prediction_heads=prediction_heads,
vocab_parallel=vocab_parallel,
num_splits=num_splits,
logits_scale_factor=logits_scale_factor,
weight=weight,
register_loss=register_loss,
)
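# Extra metrics are documented as unsupported with pipeline_parallel > 1
# (see the config field description), so reject such configurations up front.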
Assert.custom(
lambda metrics, pipeline_parallel: metrics == GRPOMetricsLevel.none or pipeline_parallel == 1,
config.metrics,
distributed_config.pipeline_parallel,
)

def _forward_backward(
self,
logits: "torch.Tensor",
@@ -51,10 +104,88 @@ def _forward_backward(
self._register_loss(
self._logprob_metric_name, new_logprobs_mean, losses, reduce_op=torch.distributed.ReduceOp.SUM
)

# Skip the extra softmax pass when there is nothing to register.
if losses is not None and self._config.metrics != GRPOMetricsLevel.none:
self._register_extra_metrics(logits, kwargs, losses, split_index)

return loss, grad

def _register_extra_metrics(
self,
logits: torch.Tensor,
kwargs: dict[str, typing.Any],
losses: dict | None,
split_index: int,
) -> None:
metrics = compute_grpo_metrics(
logits,
self._get_labels(kwargs, split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index),
self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index),
self._config.epsilon_low,
self._config.epsilon_high,
self._logits_scale_factor,
group=self._parallel_dim.group if self._vocab_parallel else None,
compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy,
)

num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch]

for attr in (
"old_logprobs",
"ratio_new_old",
"kl_new_old",
"clipped_ratio_fraction",
"advantage",
):
self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr) / num_documents, losses)

for attr in (
"ratio_new_old_sum",
"ratio_new_old_squared_sum",
"num_tokens",
):
self._register_loss(f"{self._name}_{attr}", getattr(metrics, attr), losses)

self._register_loss(
f"{self._name}_max_advantage",
metrics.max_advantage,
losses,
reduce_op=torch.distributed.ReduceOp.MAX,
)
self._register_loss(
f"{self._name}_min_advantage",
metrics.min_advantage,
losses,
reduce_op=torch.distributed.ReduceOp.MIN,
)

if metrics.entropy is not None:
self._register_loss(f"{self._name}_entropy", metrics.entropy / num_documents, losses)

def get_loss_definitions(self) -> list[LossDef]:
-return super().get_loss_definitions() + [LossDef(self._logprob_metric_name)]
defs = super().get_loss_definitions()
defs.append(LossDef(self._logprob_metric_name))
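# Keep this list in sync with the names registered in _register_extra_metrics.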
if self._config.metrics != GRPOMetricsLevel.none:
defs.extend(
[
LossDef(f"{self._name}_old_logprobs"),
LossDef(f"{self._name}_ratio_new_old"),
LossDef(f"{self._name}_ratio_new_old_sum"),
LossDef(f"{self._name}_ratio_new_old_squared_sum"),
LossDef(f"{self._name}_kl_new_old"),
LossDef(f"{self._name}_clipped_ratio_fraction"),
LossDef(f"{self._name}_advantage"),
LossDef(f"{self._name}_max_advantage", reduction=ReductionType.maximum),
LossDef(f"{self._name}_min_advantage", reduction=ReductionType.minimum),
LossDef(f"{self._name}_num_tokens"),
]
)
if self._config.metrics == GRPOMetricsLevel.with_entropy:
defs.append(LossDef(f"{self._name}_entropy"))
return defs

def get_preprocessing_config(
self,
@@ -66,6 +197,61 @@ def _logprob_metric_name(self) -> str:
return f"{self._name}_new_logprobs"


@torch.compile
def compute_grpo_metrics(
logits: torch.Tensor, # (*batch, vocab_local)
target: torch.Tensor, # (*batch,)
advantages: torch.Tensor, # (*batch,)
old_log_probabilities: torch.Tensor, # (*batch,)
label_counts: torch.Tensor, # (*batch,) — global per-sequence count broadcast per token
epsilon_low: float = 0.2,
epsilon_high: float = 0.2,
logits_scale_factor: float = 1.0,
group: torch.distributed.ProcessGroup | None = None,
compute_entropy: bool = False,
) -> GRPOMetrics:
loss_mask = target >= 0
mask = loss_mask.float()
masked = mask / label_counts.float().clamp(min=1)

logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group)
predicted_logits, _, _ = fused_predicted_logits_from_labels(logits_norm, target, loss_mask, group)
new_log_probs = predicted_logits - sum_exp_logits.log()

log_ratio = new_log_probs - old_log_probabilities
ratio = log_ratio.exp()
clipped = (ratio < 1.0 - epsilon_low) | (ratio > 1.0 + epsilon_high)
kl = ratio - log_ratio - 1.0

neg_inf = advantages.new_full((), float("-inf"))
pos_inf = advantages.new_full((), float("inf"))

entropy: torch.Tensor | None = None
if compute_entropy:
# exp_logits and logits_norm are local vocab slices — sum over the local slice, then all-reduce
# across the tensor-parallel group to recover the global E_p[logit_norm] before dividing by the
# already-global sum_exp_logits.
weighted_logits_sum = (exp_logits * logits_norm).sum(-1)
if group is not None:
all_reduce(weighted_logits_sum, op=ReduceOp.SUM, group=group)
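# Identity used below: H = -sum_i p_i * log(p_i) with p_i = exp(l_i) / Z equals log(Z) - E_p[l].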
entropy_per_token = sum_exp_logits.log() - weighted_logits_sum / sum_exp_logits
entropy = (entropy_per_token * masked).sum()

return GRPOMetrics(
old_logprobs=(old_log_probabilities * masked).sum(),
ratio_new_old=(ratio * masked).sum(),
ratio_new_old_sum=(ratio * mask).sum(),
ratio_new_old_squared_sum=(ratio * ratio * mask).sum(),
kl_new_old=(kl * masked).sum(),
clipped_ratio_fraction=(clipped.float() * masked).sum(),
advantage=(advantages * masked).sum(),
max_advantage=torch.where(loss_mask, advantages, neg_inf).max(),
min_advantage=torch.where(loss_mask, advantages, pos_inf).min(),
num_tokens=mask.sum(),
entropy=entropy,
)
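
For intuition, `kl = ratio - log_ratio - 1` is the standard low-variance "k3" estimator: with `r = p_new / p_old` and tokens sampled under the old policy, its expectation is KL(old || new). A hypothetical smoke test (the shapes, the `group=None` path, and the default epsilons are assumptions, not part of this diff):

```python
import torch

batch, seq, vocab = 4, 8, 32
logits = torch.randn(batch, seq, vocab)
target = torch.randint(0, vocab, (batch, seq))  # negative labels would mark masked positions
advantages = torch.randn(batch, seq)
# Take the old log-probs from the same logits, so every new/old ratio collapses to 1.
old_logp = torch.log_softmax(logits, -1).gather(-1, target.unsqueeze(-1)).squeeze(-1)
label_counts = torch.full((batch, seq), float(seq))  # all tokens unmasked here
metrics = compute_grpo_metrics(logits, target, advantages, old_logp, label_counts, compute_entropy=True)
# With identical old and new policies, the k3 estimate should vanish.
torch.testing.assert_close(metrics.kl_new_old, torch.tensor(0.0), atol=1e-4, rtol=0)
```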


@torch.compile
def fused_grpo_loss_forward_backward(
logits: torch.Tensor, # (*batch, vocab)