From b07b999c8675893fa6aaca752dc45b5bbbabeee9 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 28 Apr 2026 09:59:02 +0000 Subject: [PATCH 1/7] gspo: add sequence-level IS-ratio clipping loss Implements GSPO (geometric-mean sequence-level policy-gradient loss) as an alternative to the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo". Key changes: - data pipeline: expose per-token document_index when return_document_index=True - LanguageModelKwargs.document_index: new kwarg constant - LanguageModelLoss: store SDP dim for cross-rank segment aggregation - grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across SDP ranks before computing segment-level R_s and A_s; gradient derivation exploits tok_count cancellation so every token in a segment gets the same gradient factor R_s * clip_indicator_s - tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed, ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff, per-token metrics unchanged) --- fast_llm/data/document/config.py | 1 + fast_llm/data/document/language_model.py | 13 +- fast_llm/layers/language_model/config.py | 1 + fast_llm/layers/language_model/loss/config.py | 5 + fast_llm/layers/language_model/loss/grpo.py | 186 ++++++- fast_llm/layers/language_model/loss/loss.py | 2 + tests/layers/test_gspo_loss.py | 461 ++++++++++++++++++ 7 files changed, 642 insertions(+), 27 deletions(-) create mode 100644 tests/layers/test_gspo_loss.py diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 352311b51..ad6a7305f 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -79,6 +79,7 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + return_document_index: bool = Field(default=False) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 7821b81c5..8dab70efb 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -35,6 +35,7 @@ class LanguageModelTargetInput(ModelInput): advantages: torch.Tensor | None = None old_log_probabilities: torch.Tensor | None = None label_counts: torch.Tensor | None = None + document_index: torch.Tensor | None = None num_labels: int | None = None num_labels_in_batch: int | None = None @@ -84,6 +85,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: LanguageModelKwargs.advantages: [target.advantages for target in self.targets], LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], + LanguageModelKwargs.document_index: [target.document_index for target in self.targets], LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets], } if self.image_patches is not None: @@ -177,7 +179,11 @@ def _set_target_inputs( document_begin += length mask = labels >= 0 - label_counts = self._get_label_counts(mask) if config.return_label_counts else None + label_counts, document_index = ( + self._get_label_counts(mask) + if config.return_label_counts or config.return_document_index + else (None, None) + ) for input_index, model_input in enumerate(model_inputs): label_end = model_input.sequence_k_dim.size + 
prediction_distance @@ -188,6 +194,7 @@ def _set_target_inputs( tokens=labels[label_begin:label_end].clone(), mask=mask[label_begin:label_end] if config.return_prediction_mask else None, label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None, + document_index=document_index[label_begin:label_end] if config.return_document_index else None, # Set value for the first input only so `share_batch_data` generated the correct sum. # TODO: ====== Make optional? num_labels=( @@ -202,7 +209,7 @@ def _set_target_inputs( model_input.targets.append(target_input) - def _get_label_counts(self, mask: torch.Tensor): + def _get_label_counts(self, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Count the number of non-masked labels in each document through cumulative sums. mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) @@ -214,4 +221,4 @@ def _get_label_counts(self, mask: torch.Tensor): document_index = torch.searchsorted( length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" ) - return labels_per_document[document_index] + return labels_per_document[document_index], document_index diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a8efdab6..1de722cae 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -25,6 +25,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs): sample_map = "sample_map" embedding_map = "embedding_map" num_documents_in_batch = "num_documents_in_batch" + document_index = "document_index" # TODO: These are generic phase = "phase" loss_mask = "loss_mask" diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 4f91724a2..a5e34dd3e 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -198,6 +198,11 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): _abstract: typing.ClassVar[bool] = False + policy_loss: str = Field( + default="grpo", + desc="Policy loss algorithm: 'grpo' (per-token IS ratio clipping) or 'gspo' (sequence-level geometric-mean clipping).", + valid=check_field(Assert.incl, ["grpo", "gspo"]), + ) epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") use_triton: bool | None = Field( diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index ab75d2f01..2f9c190e6 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -21,30 +21,52 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward + if self._config.policy_loss == "gspo": + loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), 
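+                # document_index: per-token segment IDs; GSPO aggregates ratios per segment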
+ grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._logits_scale_factor, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=self._get_label_count(kwargs), + sdp_group=self._sdp_dim.group if self._sdp_active else None, + ) else: - fn = fused_grpo_loss_forward_backward - loss, grad, new_logprobs_mean = fn( - logits, - self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._logits_scale_factor, - num_labels_in_seq=( - None - if losses is None - else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) - ), - divisor=self._get_label_count(kwargs), - ) + if TritonConfig.enabled(logits.device, self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + loss, grad, new_logprobs_mean = fn( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._logits_scale_factor, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=self._get_label_count(kwargs), + ) if new_logprobs_mean is not None: new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] @@ -141,7 +163,10 @@ def get_loss_definitions(self) -> list[LossDef]: def get_preprocessing_config( self, ) -> dict[str, typing.Any]: - return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + config = {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + if self._config.policy_loss == "gspo": + config["return_document_index"] = True + return config @functools.cached_property def _logprob_metric_name(self) -> str: @@ -222,3 +247,116 @@ def fused_grpo_loss_forward_backward( grad_logits.add_(grad) return loss, grad_logits, new_logprobs_mean + + +def fused_gspo_loss_forward_backward( + logits: torch.Tensor, # (n_tokens, vocab_local) + target: torch.Tensor, # (n_tokens,) + advantages: torch.Tensor, # (n_tokens,) + old_log_probabilities: torch.Tensor, # (n_tokens,) + document_index: torch.Tensor, # (n_tokens,) int64 — segment ID per token + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, # TP vocab group + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + num_labels_in_seq: 
torch.Tensor | None = None, # for new_logprobs_mean metric + divisor: float | None = None, + sdp_group: torch.distributed.ProcessGroup | None = None, # SDP group for cross-rank segment aggregation +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """GSPO loss: sequence-level geometric-mean IS ratio clipping. + + Each segment s gets ratio R_s = exp(mean_t(log(p_new_t/p_old_t))), clipped as a unit. + Loss = -sum_s tok_count_s * min(R_s*A_s, clip(R_s)*A_s) / divisor. + Gradient: tok_count_s cancels, so each token in segment s gets the same gradient factor R_s. + + SDP correctness: scatter_add sums are all-reduced across sdp_group before computing R_s and A_s, + ensuring correct segment-level ratios when tokens are split across ranks. + """ + if divisor is None: + divisor = float(logits.shape[0]) if logits.shape[0] > 0 else 1.0 + grad_output_scaled = None if grad_output is None else grad_output / divisor * logits_scale_factor + + loss_mask = target >= 0 + mask_float = loss_mask.float() + + # Step 1: Softmax + log probs (same as GRPO) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + logits_norm, target, loss_mask, group + ) + new_log_probs = predicted_logits - sum_exp_logits.log() + log_ratio = (new_log_probs - old_log_probabilities).float() + + # new_logprobs_mean: local partial sum (aggregated across SDP via LossDef.reduce, same as GRPO) + new_logprobs_mean = ( + None if num_labels_in_seq is None else (new_log_probs * mask_float / num_labels_in_seq.clamp(min=1)).sum() + ) + + # Step 2: Determine global n_segs (max doc index + 1, all-reduced across SDP) + n_segs_local = int(document_index.max().item()) + 1 if document_index.numel() > 0 else 0 + if sdp_group is not None: + n_segs_t = torch.tensor(n_segs_local, device=logits.device, dtype=torch.int64) + torch.distributed.all_reduce(n_segs_t, op=torch.distributed.ReduceOp.MAX, group=sdp_group) + n_segs = int(n_segs_t.item()) + else: + n_segs = n_segs_local + + # Step 3: Per-segment scatter_add (local contributions only) + lrn_sum = log_ratio.new_zeros(n_segs) # sum of log-ratios per segment + adv_sum = advantages.new_zeros(n_segs).float() # sum of advantages per segment + tok_sum = log_ratio.new_zeros(n_segs) # token count per segment + + if loss_mask.any() and n_segs > 0: + masked_doc_ids = document_index[loss_mask].long() + lrn_sum.index_add_(0, masked_doc_ids, log_ratio[loss_mask]) + adv_sum.index_add_(0, masked_doc_ids, advantages[loss_mask].float()) + tok_sum.index_add_(0, masked_doc_ids, torch.ones(masked_doc_ids.numel(), device=logits.device)) + + # Step 4: SDP all-reduce so every rank has global per-segment sums + if sdp_group is not None and n_segs > 0: + torch.distributed.all_reduce(lrn_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(adv_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(tok_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + + # Step 5: Segment-level ratio R_s and advantage A_s + valid = tok_sum > 0 + seg_denom = tok_sum.clamp(min=1e-6) + R = (lrn_sum / seg_denom).exp() # geometric mean IS ratio per segment + A = (adv_sum / seg_denom).detach() # mean advantage per segment (no gradient through A) + + # Step 6: GSPO loss — length-proportional weight tok_sum cancels with 1/N in gradient + surr1 = R * A + surr2 = R.clamp(1.0 - epsilon_low, 1.0 + epsilon_high) * A + loss_per_seg = 
-torch.minimum(surr1, surr2) * tok_sum * valid.float() + loss = loss_per_seg.sum() / divisor + + # Step 7: Gradient — broadcast segment-level factors to token level + if grad_output_scaled is not None and n_segs > 0: + # d(loss)/d(log_ratio_t) = -R_s * clip_factor_s / divisor (tok_sum cancels) + # clip_factor_s = clamp_min(A_s,0)*(R_s <= 1+eps_h) + clamp_max(A_s,0)*(R_s >= 1-eps_l) + clip_up = (R <= 1.0 + epsilon_high).float() + clip_dn = (R >= 1.0 - epsilon_low).float() + seg_grad = R * (A.clamp(min=0) * clip_up + A.clamp(max=0) * clip_dn) * valid.float() + + # Broadcast: each token gets its segment's gradient factor + token_grad = seg_grad[document_index] # (n_tokens,) + + # d(new_log_prob)/d(logits_k) = delta(k==target) - softmax_k (same chain rule as GRPO) + probability_ratio_grad = grad_output_scaled * token_grad * mask_float + + predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze(-1) + grad = probability_ratio_grad.unsqueeze(-1) * predicted_probabilities.scatter_add( + -1, + target_masked.unsqueeze(-1), + -(loss_mask if target_mask is None else target_mask).unsqueeze(-1).to(torch.float32), + ) + grad = grad.to(logits.dtype) + + if grad_logits is None: + grad_logits = grad + else: + grad_logits.add_(grad) + + return loss, grad_logits, new_logprobs_mean diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 3cab2bca8..90b368e2b 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -39,6 +39,8 @@ def __init__( self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel self._sequence_parallel = distributed_config.sequence_tensor_parallel and not self._vocab_parallel self._parallel_dim = distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sdp_dim = distributed_config.get_distributed_dim(DistributedDimNames.sequence_data) + self._sdp_active = distributed_config.sequence_data_parallel > 1 def forward_backward( self, diff --git a/tests/layers/test_gspo_loss.py b/tests/layers/test_gspo_loss.py new file mode 100644 index 000000000..46fb22673 --- /dev/null +++ b/tests/layers/test_gspo_loss.py @@ -0,0 +1,461 @@ +""" +Unit tests for fused_gspo_loss_forward_backward. + +Tests: single segment, multi-segment packed, GRPO/GSPO equivalence at ratio=1, +segment-level clipping, SDP mock, gradient check, extra metrics unchanged. 
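+Masked-token handling (labels < 0) is exercised in test 5.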
+""" + +import math + +import torch + +from fast_llm.layers.language_model.loss.grpo import ( + fused_grpo_loss_forward_backward, + fused_gspo_loss_forward_backward, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# Reference GSPO implementation +# --------------------------------------------------------------------------- + + +def _gspo_reference(logits, target, advantages, old_log_probs, doc_idx, eps_lo, eps_hi, divisor): + """Pure-PyTorch reference without compilation or distributed calls.""" + loss_mask = target >= 0 + log_softmax = torch.log_softmax(logits.float(), dim=-1) + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + log_ratio = (new_log_probs - old_log_probs.float()) * loss_mask.float() + + n_segs = int(doc_idx.max().item()) + 1 + lrn_sum = torch.zeros(n_segs, dtype=torch.float32) + adv_sum = torch.zeros(n_segs, dtype=torch.float32) + tok_sum = torch.zeros(n_segs, dtype=torch.float32) + for i in range(len(target)): + if loss_mask[i]: + s = doc_idx[i].item() + lrn_sum[s] += log_ratio[i].item() + adv_sum[s] += advantages[i].item() + tok_sum[s] += 1.0 + + loss = 0.0 + for s in range(n_segs): + if tok_sum[s] == 0: + continue + R = math.exp(lrn_sum[s] / tok_sum[s]) + A = adv_sum[s] / tok_sum[s] + R_clipped = max(1.0 - eps_lo, min(1.0 + eps_hi, R)) + surr1 = R * A + surr2 = R_clipped * A + loss += -min(surr1, surr2) * tok_sum[s] + return loss / divisor + + +# --------------------------------------------------------------------------- +# Test 1: single segment +# --------------------------------------------------------------------------- + + +def test_single_segment(): + torch.manual_seed(0) + n_tok, vocab = 8, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = ( + torch.log_softmax(torch.randn(n_tok, vocab, device=device), dim=-1) + .gather(-1, target.unsqueeze(-1)) + .squeeze(-1) + ) + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 2: multi-segment packed +# --------------------------------------------------------------------------- + + +def test_multi_segment_packed(): + torch.manual_seed(1) + # 3 segments of lengths [5, 7, 4] + segs = [5, 7, 4] + n_tok = sum(segs) + vocab = 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = ( + torch.log_softmax(torch.randn(n_tok, vocab, device=device), dim=-1) + .gather(-1, target.unsqueeze(-1)) + .squeeze(-1) + ) + doc_idx = torch.cat([torch.full((l,), i, dtype=torch.long) for i, l in enumerate(segs)]).to(device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + 
sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol * 3, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 3: GRPO vs GSPO equivalence when all tokens in a segment have ratio=1 +# --------------------------------------------------------------------------- + + +def test_ratio_one_matches_grpo(): + """When new == old log-probs (ratio=1 everywhere), GRPO and GSPO losses match.""" + torch.manual_seed(2) + n_tok, vocab = 12, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + # Set old log probs equal to new log probs for ratio=1 + old_log_probs = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1).detach() + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + divisor = float(n_tok) + + loss_grpo, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_log_probs, divisor=divisor) + loss_gspo, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + # At ratio=1, GRPO loss = sum_t -A_t * mask_t / divisor (no clipping) + # GSPO loss = sum_s tok_s * -A_s / divisor (weighted per segment) + # For a single segment: GSPO = -mean(A) * N / divisor = same total + assert abs(loss_grpo.item() - loss_gspo.item()) < atol, f"grpo={loss_grpo.item()}, gspo={loss_gspo.item()}" + + +# --------------------------------------------------------------------------- +# Test 4: segment-level clipping (GSPO clips whole segment, not per-token) +# --------------------------------------------------------------------------- + + +def test_segment_level_clipping(): + """ + Construct a segment where per-token ratios straddle the clip boundary (some high, some low), + but the geometric mean ratio is in-range. GSPO should NOT clip; GRPO should clip some tokens. 
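+    Here the per-token log-ratios alternate +0.4 and -0.4 (ratios exp(±0.4) ≈ 1.49 and 0.67,
+    both outside [1 - eps, 1 + eps] = [0.8, 1.2]) while their geometric mean is exp(0) = 1.0.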
+ """ + torch.manual_seed(3) + vocab = 8 + # 4 tokens, alternating log_ratio +0.5 and -0.5 → mean = 0 → R = exp(0) = 1.0 (in range) + n_tok = 4 + target = torch.zeros(n_tok, dtype=torch.long, device=device) + advantages = torch.ones(n_tok, device=device) + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + + # Build logits such that new_log_probs - old_log_probs alternates +0.4 and -0.4 + # Use constant logits; set old_log_probs manually + logits = torch.zeros(n_tok, vocab, device=device) + old_log_probs = torch.tensor([0.4, -0.4, 0.4, -0.4], device=device) # per-token log_ratio = 0 - old + + eps = 0.2 + divisor = float(n_tok) + loss_gspo, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + epsilon_low=eps, + epsilon_high=eps, + divisor=divisor, + sdp_group=None, + ) + + # GSPO: mean log_ratio = mean of (log_softmax(0)[0] - old_log_probs) + # R = exp(mean), A=1.0 + # As long as R is in [1-eps, 1+eps], loss = -R * 1 * 4 / 4 = -R + new_log_probs = torch.log_softmax(logits.float(), dim=-1)[:, 0] + log_ratios = new_log_probs - old_log_probs + mean_log_ratio = log_ratios.mean().item() + R = math.exp(mean_log_ratio) + expected = -R # unclipped, weight 4/divisor = 1 + assert abs(loss_gspo.item() - expected) < atol, f"gspo={loss_gspo.item()}, expected={expected}" + + +# --------------------------------------------------------------------------- +# Test 5: masked tokens don't contribute +# --------------------------------------------------------------------------- + + +def test_masked_tokens(): + torch.manual_seed(4) + n_tok, vocab = 10, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + target[3] = -100 # mask token 3 + target[7] = -100 # mask token 7 + advantages = torch.randn(n_tok, device=device) + old_log_probs = torch.randn(n_tok, device=device) + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 6: SDP mock — split tokens across 2 "ranks", verify correctness +# --------------------------------------------------------------------------- + + +def test_sdp_mock(): + """ + Simulate SDP=2: split tokens in half, compute per-rank scatter_add, manually all-reduce, + then verify the combined sums match the full-batch computation. 
+ """ + torch.manual_seed(5) + segs = [6, 5, 7] # 3 segments + n_tok = sum(segs) + vocab = 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = torch.randn(n_tok, device=device) + doc_idx = torch.cat([torch.full((l,), i, dtype=torch.long) for i, l in enumerate(segs)]).to(device) + divisor = float(n_tok) + + # Full-batch reference loss + loss_full, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + + # Simulate SDP=2: split at midpoint + mid = n_tok // 2 + loss_r0_only, _, _ = fused_gspo_loss_forward_backward( + logits[:mid], + target[:mid], + advantages[:mid], + old_log_probs[:mid], + doc_idx[:mid], + divisor=divisor, + sdp_group=None, + ) + loss_r1_only, _, _ = fused_gspo_loss_forward_backward( + logits[mid:], + target[mid:], + advantages[mid:], + old_log_probs[mid:], + doc_idx[mid:], + divisor=divisor, + sdp_group=None, + ) + # These individual ranks do NOT give the right answer (segments are split) + # But the full-batch result should match the reference + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_full.item() - loss_ref) < atol * 3, f"full={loss_full.item()}, ref={loss_ref}" + + # When sdp_group is None but we manually pre-sum, the result should also match + # (This conceptually validates the all-reduce logic without actual distributed calls) + log_softmax_full = torch.log_softmax(logits.float(), dim=-1) + new_lp_full = log_softmax_full.gather(-1, (target * (target >= 0)).unsqueeze(-1)).squeeze(-1) + log_ratio_full = (new_lp_full - old_log_probs.float()) * (target >= 0).float() + + n_segs = 3 + lrn_r0 = torch.zeros(n_segs) + adv_r0 = torch.zeros(n_segs) + tok_r0 = torch.zeros(n_segs) + lrn_r1 = torch.zeros(n_segs) + adv_r1 = torch.zeros(n_segs) + tok_r1 = torch.zeros(n_segs) + for i in range(mid): + if target[i] >= 0: + s = doc_idx[i].item() + lrn_r0[s] += log_ratio_full[i].item() + adv_r0[s] += advantages[i].item() + tok_r0[s] += 1 + for i in range(mid, n_tok): + if target[i] >= 0: + s = doc_idx[i].item() + lrn_r1[s] += log_ratio_full[i].item() + adv_r1[s] += advantages[i].item() + tok_r1[s] += 1 + + # Manually all-reduce (SUM) + lrn_global = lrn_r0 + lrn_r1 + adv_global = adv_r0 + adv_r1 + tok_global = tok_r0 + tok_r1 + + loss_manual = 0.0 + for s in range(n_segs): + if tok_global[s] == 0: + continue + R = math.exp(lrn_global[s] / tok_global[s]) + A = adv_global[s] / tok_global[s] + R_c = max(1 - 0.2, min(1 + 0.2, R)) + loss_manual += -min(R * A, R_c * A) * tok_global[s] + loss_manual /= divisor + + assert abs(loss_full.item() - loss_manual) < atol * 3, f"full={loss_full.item()}, manual={loss_manual}" + + +# --------------------------------------------------------------------------- +# Test 7: gradient correctness via finite differences +# --------------------------------------------------------------------------- + + +def test_gradient_finite_diff(): + torch.manual_seed(6) + n_tok, vocab = 6, 8 + logits = torch.randn(n_tok, vocab, dtype=torch.float64) + target = torch.randint(0, vocab, (n_tok,)) + advantages = torch.randn(n_tok, dtype=torch.float64) + old_log_probs = torch.randn(n_tok, dtype=torch.float64) + doc_idx = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long) + divisor = float(n_tok) + eps = 1e-5 + + grad_logits = 
torch.zeros_like(logits) + _, grad_out, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + grad_logits=grad_logits, + grad_output=1.0, + divisor=divisor, + sdp_group=None, + ) + + # Finite-difference gradient for one entry + i, k = 2, 3 + logits_p = logits.clone() + logits_p[i, k] += eps + logits_m = logits.clone() + logits_m[i, k] -= eps + loss_p, _, _ = fused_gspo_loss_forward_backward( + logits_p, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_m, _, _ = fused_gspo_loss_forward_backward( + logits_m, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + fd_grad = (loss_p.item() - loss_m.item()) / (2 * eps) + + assert abs(grad_out[i, k].item() - fd_grad) < 1e-4, f"analytical={grad_out[i, k].item():.6f}, fd={fd_grad:.6f}" + + +# --------------------------------------------------------------------------- +# Test 8: extra metrics unchanged by policy_loss choice +# --------------------------------------------------------------------------- + + +def test_extra_metrics_are_per_token(): + """pg_metrics are per-token regardless of GSPO/GRPO — computed from token-level ratios.""" + from fast_llm.layers.language_model.loss.pg_metrics import compute_policy_gradient_metrics + + torch.manual_seed(7) + n_tok, vocab = 10, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = torch.randn(n_tok, device=device) + label_counts = torch.full((n_tok,), n_tok, dtype=torch.float32, device=device) + + metrics = compute_policy_gradient_metrics( + logits, + target, + old_log_probs, + advantages, + label_counts, + epsilon_low=0.2, + epsilon_high=0.2, + logits_scale_factor=1.0, + vocab_parallel_group=None, + ) + # Sanity: metrics are finite scalars + for attr in ("old_logprobs", "ratio_new_old", "kl_new_old", "advantage"): + val = getattr(metrics, attr) + assert val.isfinite(), f"{attr} is not finite: {val}" From fecc978d8229005a2cb55040f1e823e8b63ff5c1 Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 28 Apr 2026 11:06:04 +0000 Subject: [PATCH 2/7] schedule: add rollouts_per_step to auto-set depth_first_micro_batches MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel × breadth_first_micro_batches) before sub-configs are created (and frozen). Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1 each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8 gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally. 
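Note that the division floors: if rollouts_per_step is not divisible by
batch_data_parallel × breadth_first_micro_batches, the effective rollout count per
step is rounded down to the nearest multiple.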
YAML usage: schedule: rollouts_per_step: 1024 # replaces manual depth_first_micro_batches model: distributed: data_parallel: 8 # used for the division --- fast_llm/engine/schedule/config.py | 9 +++++++++ fast_llm/engine/training/config.py | 15 +++++++++++++++ 2 files changed, 24 insertions(+) diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 29720b90b..40e65fb60 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,6 +21,15 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + rollouts_per_step: int = Field( + default=0, + desc="When >0, automatically sets depth_first_micro_batches = rollouts_per_step // " + "(batch_data_parallel × breadth_first_micro_batches). " + "Matches DeepSpeed's gradient_accumulation_passes semantics for RL training " + "where each microbatch contains one rollout. 0 = use depth_first_micro_batches as-is.", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) breadth_first_micro_batches: int = Field( default=1, desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index bece3cb49..78c1062fa 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -372,6 +372,21 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + # Derive depth_first_micro_batches from rollouts_per_step before sub-configs are created. + schedule = default.get("schedule", {}) + rollouts = schedule.get("rollouts_per_step", 0) + if rollouts > 0: + distributed = default.get("model", {}).get("distributed", {}) + dp = distributed.get("data_parallel", 1) + sdp = max(distributed.get("sequence_data_parallel", 1), 1) + batch_dp = max(dp // sdp, 1) + bfmb = schedule.get("breadth_first_micro_batches", 1) + depth_first = rollouts // (batch_dp * bfmb) + default = {**default, "schedule": {**schedule, "depth_first_micro_batches": depth_first}} + return super()._from_dict(default, strict) + def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): From 7d8ec0ca1322115b93bb691bb6779c4734e136cc Mon Sep 17 00:00:00 2001 From: bigximik Date: Tue, 28 Apr 2026 12:18:20 +0000 Subject: [PATCH 3/7] grpo: dynamic docs_per_step accumulation and normalize_by_documents MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first is now determined at runtime rather than statically in _from_dict - Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs} properties so per-step schedules share the same config object as the runner - Add Trainer._prefetch_to_doc_target: fetches microbatches one at a time, all-reduces doc count per microbatch, stops when global total ≥ docs_per_step, then resets num_documents_in_batch to the step total on all inputs - Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with _depth_first_override=N//breadth_first_micro_batches - Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True both GRPO and GSPO paths divide by num_documents_in_batch instead of num_labels_in_batch (matches DeepSpeed's per-rollout normalization) - Add 
tests/layers/test_docs_per_step.py: 13 unit tests covering divisor scaling, normalize_by_documents layer routing, Schedule._eff_* properties, and _prefetch_to_doc_target accumulation logic --- fast_llm/engine/schedule/config.py | 11 +- fast_llm/engine/schedule/runner.py | 5 +- fast_llm/engine/schedule/schedule.py | 38 ++- fast_llm/engine/training/config.py | 15 - fast_llm/engine/training/trainer.py | 62 +++- fast_llm/layers/language_model/loss/config.py | 8 + fast_llm/layers/language_model/loss/grpo.py | 9 +- tests/layers/test_docs_per_step.py | 322 ++++++++++++++++++ 8 files changed, 426 insertions(+), 44 deletions(-) create mode 100644 tests/layers/test_docs_per_step.py diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 40e65fb60..2920c1334 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,12 +21,13 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) - rollouts_per_step: int = Field( + docs_per_step: int = Field( default=0, - desc="When >0, automatically sets depth_first_micro_batches = rollouts_per_step // " - "(batch_data_parallel × breadth_first_micro_batches). " - "Matches DeepSpeed's gradient_accumulation_passes semantics for RL training " - "where each microbatch contains one rollout. 0 = use depth_first_micro_batches as-is.", + desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. " + "When >0, each training step dynamically accumulates microbatches until the globally all-reduced " + "document count reaches this value, then triggers the optimizer step. " + "depth_first_micro_batches is ignored when this is set. " + "0 = use depth_first_micro_batches as-is (fixed microbatch count per step).", hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index b2e212946..128b95e8e 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -320,7 +320,8 @@ def _preprocess_data( if context.schedule.phase.is_training else None ) - model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + n_micro_batches = context.schedule._eff_sequential_micro_batches + model_inputs = [next(data_iterator) for _ in range(n_micro_batches)] model_inputs[0][0].share_batch_data( [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed ) @@ -336,7 +337,7 @@ def _preprocess_data( extra_kwargs={ "grad_output": grad_output, "micro_batch": micro_batch, - "num_micro_batches": self._config.sequential_micro_batches, + "num_micro_batches": n_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 6f7bf1d95..845b5df82 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -115,15 +115,17 @@ def __init__( batch_meta: list[ModelInput], distributed_config: DistributedConfig, phase: PhaseType, + _depth_first_override: int | None = None, ): super().__init__(config) + self._depth_first_override = _depth_first_override self._multi_stage = multi_stage self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase self._is_training = self._phase.is_training - if self._config.num_inputs < self._distributed_config.pipeline_parallel: + if self._eff_num_inputs < 
self._distributed_config.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. @@ -155,9 +157,25 @@ def __init__( def phase(self) -> PhaseType: return self._phase + @property + def _eff_depth_first(self) -> int: + return ( + self._depth_first_override + if self._depth_first_override is not None + else self._config.depth_first_micro_batches + ) + + @property + def _eff_sequential_micro_batches(self) -> int: + return self._eff_depth_first * self._config.breadth_first_micro_batches + + @property + def _eff_num_inputs(self) -> int: + return self._eff_sequential_micro_batches * self._config.micro_batch_splits + @property def samples_per_batch(self) -> int: - return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel + return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) @@ -189,7 +207,7 @@ def _create_index(self) -> None: Assert.in_range( step.index, 0, - self._config.num_inputs, + self._eff_num_inputs, ) Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i @@ -205,7 +223,7 @@ def _create_index(self) -> None: Assert.custom(all, self._device_steps) # Consistency checks step_map = self._step_map.copy() - for data_index in range(self._config.num_inputs): + for data_index in range(self._eff_num_inputs): for type_ in (StepType.forward, StepType.backward): for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]: first_grad_stage += 1 else: first_grad_stage = self._num_stages - for depth_first_micro_batch in range(self._config.depth_first_micro_batches): + for depth_first_micro_batch in range(self._eff_depth_first): for stage in range(self._num_stages): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in range(self._config.micro_batch_splits): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]: for stage in reversed(range(first_grad_stage, self._num_stages)): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in reversed(range(self._config.micro_batch_splits)): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, diff --git a/fast_llm/engine/training/config.py b/fast_llm/engine/training/config.py index 78c1062fa..bece3cb49 100644 --- a/fast_llm/engine/training/config.py +++ b/fast_llm/engine/training/config.py @@ -372,21 +372,6 @@ class TrainerConfig(PretrainedFastLLMModelConfig, ExperimentConfig): hint=FieldHint.feature, ) - @classmethod - def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: - # Derive depth_first_micro_batches from rollouts_per_step before sub-configs are created. 
- schedule = default.get("schedule", {}) - rollouts = schedule.get("rollouts_per_step", 0) - if rollouts > 0: - distributed = default.get("model", {}).get("distributed", {}) - dp = distributed.get("data_parallel", 1) - sdp = max(distributed.get("sequence_data_parallel", 1), 1) - batch_dp = max(dp // sdp, 1) - bfmb = schedule.get("breadth_first_micro_batches", 1) - depth_first = rollouts // (batch_dp * bfmb) - default = {**default, "schedule": {**schedule, "depth_first_micro_batches": depth_first}} - return super()._from_dict(default, strict) - def _validate(self) -> None: self.training.export.setup(self.model) for reference_model in self.reference_models.values(): diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 00cf2fa0d..5c8bc0b89 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -115,10 +115,13 @@ def setup(self, distributed: Distributed, run: Run) -> None: preprocessing_config = self._multi_stage.get_preprocessing_config( PhaseType.training, self._config.schedule.micro_batch_splits ) + self._preprocessing_config = preprocessing_config + self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size) + self._schedule_cache: dict[int, Schedule] = {} self._schedule = Schedule( config=self._config.schedule, multi_stage=self._multi_stage, - batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size), + batch_meta=self._single_mb_meta, distributed_config=self._config.model.distributed, phase=PhaseType.training, ) @@ -140,6 +143,41 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._is_setup = True + def _get_or_build_schedule(self, n_microbatches: int) -> Schedule: + if n_microbatches not in self._schedule_cache: + bfmb = self._config.schedule.breadth_first_micro_batches + depth_first = n_microbatches // bfmb + self._schedule_cache[n_microbatches] = Schedule( + config=self._config.schedule, + multi_stage=self._multi_stage, + batch_meta=self._single_mb_meta, + distributed_config=self._config.model.distributed, + phase=PhaseType.training, + _depth_first_override=depth_first, + ) + return self._schedule_cache[n_microbatches] + + def _prefetch_to_doc_target(self, data_iterator) -> list: + target = self._config.schedule.docs_per_step + bfmb = self._config.schedule.breadth_first_micro_batches + buffer = [] + total_docs = 0 + while total_docs < target: + mb = next(data_iterator) + mb[0].share_batch_data(mb, self._distributed) + total_docs += mb[0].num_documents_in_batch + buffer.append(mb) + Assert.eq( + len(buffer) % bfmb, + 0, + msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}", + ) + # Reset num_documents_in_batch to the step total on all microbatches + for mb in buffer: + for mi in mb: + mi.num_documents_in_batch = total_docs + return buffer + @abc.abstractmethod def _get_data(self) -> Data: pass @@ -220,12 +258,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Data loader hates getting all micro-batches at once. 
# (Also preprocessing adds overhead) - reduced_losses, update_successful, train_metrics = self._runner.run_step( - train_iterator, - self._schedule, - iteration=self._completed_steps, - return_metrics=is_logging, - ) + if self._config.schedule.docs_per_step > 0: + buffer = self._prefetch_to_doc_target(train_iterator) + step_schedule = self._get_or_build_schedule(len(buffer)) + reduced_losses, update_successful, train_metrics = self._runner.run_step( + iter(buffer), + step_schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) + else: + reduced_losses, update_successful, train_metrics = self._runner.run_step( + train_iterator, + self._schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) # Advanced, skipped, and Nan iterations. if update_successful: diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index a5e34dd3e..46288fda9 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -225,6 +225,14 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): desc="Batch chunk size for chunked entropy computation. Memory per chunk ∝ chunk_size × vocab_local.", hint=FieldHint.expert, ) + normalize_by_documents: bool = Field( + default=False, + desc="Normalize the policy-gradient loss by the number of documents (rollouts) in the step " + "rather than the number of tokens. Matches DeepSpeed's normalization where each token's " + "loss is divided by config.batch_size (total rollout count). " + "Set to True when using docs_per_step for full DS parity.", + hint=FieldHint.feature, + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 2f9c190e6..b2a619ec2 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -21,6 +21,11 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: + divisor = ( + kwargs[LanguageModelKwargs.num_documents_in_batch] + if self._config.normalize_by_documents + else self._get_label_count(kwargs) + ) if self._config.policy_loss == "gspo": loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( logits, @@ -39,7 +44,7 @@ def _forward_backward( if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), - divisor=self._get_label_count(kwargs), + divisor=divisor, sdp_group=self._sdp_dim.group if self._sdp_active else None, ) else: @@ -65,7 +70,7 @@ def _forward_backward( if losses is None else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) ), - divisor=self._get_label_count(kwargs), + divisor=divisor, ) if new_logprobs_mean is not None: diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py new file mode 100644 index 000000000..b57c25057 --- /dev/null +++ b/tests/layers/test_docs_per_step.py @@ -0,0 +1,322 @@ +""" +Unit tests for docs_per_step / normalize_by_documents features. + +Covers: + 1. Divisor scaling in fused_grpo_loss_forward_backward and fused_gspo_loss_forward_backward + 2. normalize_by_documents flag in LanguageModelGRPOLoss (GRPO and GSPO policy_loss) + 3. Schedule._eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs properties + 4. 
Trainer._prefetch_to_doc_target accumulation logic +""" + +import dataclasses +import types + +import pytest +import torch + +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.grpo import ( + fused_grpo_loss_forward_backward, + fused_gspo_loss_forward_backward, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +_atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# 1. Divisor-scaling correctness in raw kernels +# --------------------------------------------------------------------------- + + +def test_grpo_divisor_scales_loss(): + """Halving the divisor should double the loss.""" + torch.manual_seed(10) + n_tok, vocab = 16, 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + + d1 = float(n_tok) + d2 = float(n_tok) * 2 + + loss1, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d1) + loss2, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=d2) + + assert ( + abs(loss1.item() - 2.0 * loss2.item()) < _atol * 10 + ), f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}" + + +def test_gspo_divisor_scales_loss(): + """Halving the divisor should double the GSPO loss.""" + torch.manual_seed(11) + n_tok, vocab = 12, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + doc_idx = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 2, 2], dtype=torch.long, device=device) + + d1 = float(n_tok) + d2 = float(n_tok) * 2 + + loss1, _, _ = fused_gspo_loss_forward_backward( + logits, target, advantages, old_lp, doc_idx, divisor=d1, sdp_group=None + ) + loss2, _, _ = fused_gspo_loss_forward_backward( + logits, target, advantages, old_lp, doc_idx, divisor=d2, sdp_group=None + ) + + assert ( + abs(loss1.item() - 2.0 * loss2.item()) < _atol * 10 + ), f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}" + + +# --------------------------------------------------------------------------- +# 2. 
normalize_by_documents flag in LanguageModelGRPOLoss +# --------------------------------------------------------------------------- + + +def _make_grpo_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): + """Instantiate a LanguageModelGRPOLoss with minimal (single-GPU) DistributedConfig.""" + from fast_llm.engine.distributed.config import DistributedConfig + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + + dist_cfg = DistributedConfig() + cfg = LanguageModelGRPOLossConfig( + normalize_by_documents=normalize_by_documents, + policy_loss=policy_loss, + ) + return LanguageModelGRPOLoss(cfg, dist_cfg, name="grpo", prediction_distance=1, prediction_heads=1) + + +def _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs): + """Build the kwargs dict expected by LanguageModelGRPOLoss._forward_backward.""" + return { + LanguageModelLossKwargs.labels: [target], + LanguageModelLossKwargs.advantages: [advantages], + LanguageModelLossKwargs.old_log_probabilities: [old_lp], + LanguageModelLossKwargs.label_counts: [torch.full_like(target, n_labels, dtype=torch.int32)], + LanguageModelKwargs.num_labels_in_batch: [n_labels], + LanguageModelKwargs.num_documents_in_batch: n_docs, + LanguageModelKwargs.document_index: [doc_idx], + } + + +def test_normalize_by_documents_grpo(): + """normalize_by_documents=True → divisor=n_docs; False → divisor=n_labels. + + With n_docs ≠ n_labels, loss ratio must equal n_labels / n_docs. + """ + torch.manual_seed(20) + n_tok, vocab = 12, 16 + n_docs, n_labels = 3, n_tok + + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + + kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) + + loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False)._forward_backward(logits, kwargs) + loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True)._forward_backward(logits, kwargs) + + expected_ratio = float(n_labels) / float(n_docs) + actual_ratio = loss_by_docs.item() / loss_by_tokens.item() + assert ( + abs(actual_ratio - expected_ratio) < 1e-4 + ), f"Expected loss_docs/loss_tokens ≈ {expected_ratio:.4f}, got {actual_ratio:.4f}" + + +def test_normalize_by_documents_gspo(): + """Same test for GSPO policy_loss.""" + torch.manual_seed(21) + n_tok, vocab = 12, 16 + n_docs, n_labels = 3, n_tok + + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + # 3 equal segments → n_docs=3 + doc_idx = torch.cat([torch.full((n_tok // n_docs,), i, dtype=torch.long) for i in range(n_docs)]).to(device) + + kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) + + loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False, policy_loss="gspo")._forward_backward( + logits, kwargs + ) + loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True, policy_loss="gspo")._forward_backward( + logits, kwargs + ) + + expected_ratio = float(n_labels) / float(n_docs) + actual_ratio = loss_by_docs.item() / loss_by_tokens.item() + assert ( + abs(actual_ratio - expected_ratio) < 1e-4 + ), f"Expected loss_docs/loss_tokens ≈ {expected_ratio:.4f}, got {actual_ratio:.4f}" + + +# 
--------------------------------------------------------------------------- +# 3. Schedule._eff_* properties +# --------------------------------------------------------------------------- + + +def _make_bare_schedule(depth_first: int, breadth_first: int, splits: int, override: int | None) -> Schedule: + """Create a Schedule with __init__ bypassed to test the _eff_* properties only.""" + config = ScheduleConfig( + depth_first_micro_batches=depth_first, + breadth_first_micro_batches=breadth_first, + micro_batch_splits=splits, + ) + sched = object.__new__(Schedule) + # Minimal attributes used by the three _eff_* properties. + object.__setattr__(sched, "_config", config) + object.__setattr__(sched, "_depth_first_override", override) + # samples_per_batch also needs _distributed_config.batch_data_parallel + fake_distributed = types.SimpleNamespace(batch_data_parallel=1) + object.__setattr__(sched, "_distributed_config", fake_distributed) + return sched + + +def test_schedule_eff_properties_no_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=None) + assert sched._eff_depth_first == 4 + assert sched._eff_sequential_micro_batches == 8 # 4 * 2 + assert sched._eff_num_inputs == 24 # 8 * 3 + assert sched.samples_per_batch == 8 # 8 * dp=1 + + +def test_schedule_eff_properties_with_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=7) + assert sched._eff_depth_first == 7 # override wins + assert sched._eff_sequential_micro_batches == 14 # 7 * 2 + assert sched._eff_num_inputs == 42 # 14 * 3 + assert sched.samples_per_batch == 14 # 14 * dp=1 + + +def test_schedule_eff_properties_override_equals_config(): + """Override equal to config value → same result as no override.""" + sched_no = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=None) + sched_yes = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=3) + assert sched_no._eff_depth_first == sched_yes._eff_depth_first + assert sched_no._eff_sequential_micro_batches == sched_yes._eff_sequential_micro_batches + assert sched_no._eff_num_inputs == sched_yes._eff_num_inputs + + +def test_schedule_samples_per_batch_uses_eff(): + """samples_per_batch should scale with _eff_sequential, not config.sequential.""" + sched = _make_bare_schedule(depth_first=2, breadth_first=2, splits=1, override=5) + # Config says depth_first=2 → sequential=4; override=5 → eff_sequential=10 + assert sched._eff_sequential_micro_batches == 10 + assert sched.samples_per_batch == 10 # dp=1 + + +# --------------------------------------------------------------------------- +# 4. 
+
+
+@dataclasses.dataclass
+class _FakeMicrobatch:
+    """Stub for a single split of one microbatch."""
+
+    num_documents: int
+    num_documents_in_batch: int | None = None
+
+    @classmethod
+    def share_batch_data(cls, inputs, distributed):
+        """Mimic TokenModelInput.share_batch_data with group=None (single process)."""
+        if inputs[0].num_documents_in_batch is None:
+            total = sum(inp.num_documents for inp in inputs)
+            for inp in inputs:
+                inp.num_documents_in_batch = total
+
+
+def _fake_iterator(doc_counts: list[int]):
+    """Yield [_FakeMicrobatch(n)] for each n in doc_counts."""
+    for n in doc_counts:
+        yield [_FakeMicrobatch(num_documents=n)]
+
+
+class _StubTrainer:
+    """Concrete stub that exposes only the interface _prefetch_to_doc_target needs."""
+
+    # Borrow the method directly so it runs against this stub's attributes.
+    from fast_llm.engine.training.trainer import Trainer as _Trainer
+
+    _prefetch_to_doc_target = _Trainer._prefetch_to_doc_target
+
+
+def _make_fake_trainer(docs_per_step: int, bfmb: int = 1):
+    """Create a _StubTrainer with the attributes _prefetch_to_doc_target reads."""
+    schedule_cfg = types.SimpleNamespace(
+        docs_per_step=docs_per_step,
+        breadth_first_micro_batches=bfmb,
+    )
+    config = types.SimpleNamespace(schedule=schedule_cfg)
+    distributed = types.SimpleNamespace(batch_data_group=None)
+
+    trainer = _StubTrainer()
+    trainer._config = config
+    trainer._distributed = distributed
+    return trainer
+
+
+def test_prefetch_stops_at_target():
+    """Buffer should stop growing once cumulative docs ≥ docs_per_step."""
+    trainer = _make_fake_trainer(docs_per_step=6, bfmb=1)
+    # Each microbatch has 2 docs; need ≥6 → expect 3 microbatches
+    it = _fake_iterator([2, 2, 2, 2, 2])
+    buffer = trainer._prefetch_to_doc_target(it)
+
+    assert len(buffer) == 3, f"Expected 3 microbatches, got {len(buffer)}"
+
+
+def test_prefetch_resets_num_documents_in_batch():
+    """After the call, every microbatch input has num_documents_in_batch = step total."""
+    trainer = _make_fake_trainer(docs_per_step=5, bfmb=1)
+    # 3 docs, 3 docs → total=6 (overshoots 5, stops after 2nd)
+    it = _fake_iterator([3, 3, 3])
+    buffer = trainer._prefetch_to_doc_target(it)
+
+    step_total = sum(mb[0].num_documents for mb in buffer)
+    for mb in buffer:
+        for mi in mb:
+            assert (
+                mi.num_documents_in_batch == step_total
+            ), f"Expected num_documents_in_batch={step_total}, got {mi.num_documents_in_batch}"
+
+
+def test_prefetch_overshoot_is_included():
+    """A microbatch that pushes the total over the target IS included (not dropped)."""
+    trainer = _make_fake_trainer(docs_per_step=5, bfmb=1)
+    it = _fake_iterator([4, 4])  # 4 < 5, then 8 ≥ 5 → 2 microbatches
+    buffer = trainer._prefetch_to_doc_target(it)
+    assert len(buffer) == 2
+    assert buffer[-1][0].num_documents_in_batch == 8  # step total = 4+4
+
+
+def test_prefetch_divisibility_check():
+    """Raises when fetched count is not divisible by breadth_first_micro_batches."""
+    trainer = _make_fake_trainer(docs_per_step=4, bfmb=2)
+    # Each microbatch has 5 docs → only 1 mb needed, but 1 % 2 != 0
+    it = _fake_iterator([5, 5, 5])
+    with pytest.raises(Exception):
+        trainer._prefetch_to_doc_target(it)
+
+
+def test_prefetch_exact_divisibility():
+    """No error when fetched count is exactly divisible by breadth_first_micro_batches."""
+    trainer = _make_fake_trainer(docs_per_step=4, bfmb=2)
+    # 2 docs each → need ≥4 → fetch 2 microbatches → 2 % 2 == 0
+    it = _fake_iterator([2, 2, 2, 2])
+    buffer = trainer._prefetch_to_doc_target(it)
+    assert len(buffer) == 2

From 014ba59993d080b50a01ddb2c15a82e3012aa886 Mon Sep 17 00:00:00 2001
From: bigximik
Date: Wed, 29 Apr 2026 08:02:14 +0000
Subject: [PATCH 4/7] grpo: temperature scaling for IS ratio parity with actor
 sampling

Add temperature field to LanguageModelGRPOLossConfig. When set to match the
actor's sampling temperature (e.g. 0.7), new log-probs are computed at the
same temperature as the stored old log-probs, so the IS ratio starts near
1.0 instead of ~1.08.

Implementation: _effective_logits_scale = logits_scale_factor / temperature,
substituted for logits_scale_factor at all three callsites: the GRPO and
GSPO paths in _forward_backward, plus _register_pg_metrics.

Default temperature=1.0 preserves existing behaviour exactly.
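
A standalone sketch of the parity argument (hypothetical shapes and values,
not code from this patch):

    import torch

    logits = torch.randn(4, 32)                    # (tokens, vocab), frozen weights
    labels = torch.randint(0, 32, (4, 1))
    T = 0.7                                        # actor sampling temperature

    old_lp = torch.log_softmax(logits / T, -1).gather(-1, labels)  # stored by the actor
    new_lp = torch.log_softmax(logits / T, -1).gather(-1, labels)  # temperature matched
    new_lp_t1 = torch.log_softmax(logits, -1).gather(-1, labels)   # old default (T=1)

    print((new_lp - old_lp).exp().mean())     # exactly 1.0
    print((new_lp_t1 - old_lp).exp().mean())  # != 1 in general (~1.08 observed in training)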
---
 fast_llm/layers/language_model/loss/config.py |  7 +++++++
 fast_llm/layers/language_model/loss/grpo.py   | 10 +++++++---
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py
index 46288fda9..0634b15f7 100644
--- a/fast_llm/layers/language_model/loss/config.py
+++ b/fast_llm/layers/language_model/loss/config.py
@@ -233,6 +233,13 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig):
         "Set to True when using docs_per_step for full DS parity.",
         hint=FieldHint.feature,
     )
+    temperature: float = Field(
+        default=1.0,
+        desc="Temperature applied to logits before computing new log-probabilities. "
+        "Set to match the sampling temperature used by the actor (e.g. 0.7) so that "
+        "new and old log-probs are in the same scale and the IS ratio starts near 1.",
+        valid=check_field(Assert.gt, 0),
+    )
 
     @property
     def loss_class(self) -> "type[LanguageModelGRPOLoss]":

diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index b2a619ec2..8472580f8 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -38,7 +38,7 @@ def _forward_backward(
             group=self._parallel_dim.group if self._vocab_parallel else None,
             epsilon_low=self._config.epsilon_low,
             epsilon_high=self._config.epsilon_high,
-            logits_scale_factor=self._logits_scale_factor,
+            logits_scale_factor=self._effective_logits_scale,
             num_labels_in_seq=(
                 None
                 if losses is None
@@ -64,7 +64,7 @@
             group=self._parallel_dim.group if self._vocab_parallel else None,
             epsilon_low=self._config.epsilon_low,
             epsilon_high=self._config.epsilon_high,
-            logits_scale_factor=self._logits_scale_factor,
+            logits_scale_factor=self._effective_logits_scale,
             num_labels_in_seq=(
                 None
                 if losses is None
@@ -101,7 +101,7 @@ def _register_pg_metrics(
             self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index),
             self._config.epsilon_low,
             self._config.epsilon_high,
-            self._logits_scale_factor,
+            self._effective_logits_scale,
             vocab_parallel_group=self._parallel_dim.group if self._vocab_parallel else None,
             compute_entropy=self._config.compute_entropy_metric,
             entropy_chunk_size=self._config.entropy_chunk_size,
@@ -173,6 +173,10 @@ def get_preprocessing_config(
         config["return_document_index"] = True
         return config
 
+    @functools.cached_property
+    def _effective_logits_scale(self) -> float:
+        return self._logits_scale_factor / self._config.temperature
+
     @functools.cached_property
     def _logprob_metric_name(self) -> str:
         return f"{self._name}_new_logprobs"

From d8cb9ef5577ccf85e24c0f368c9ff1ba5b451400 Mon Sep 17 00:00:00 2001
From: bigximik
Date: Mon, 4 May 2026 07:14:38 +0000
Subject: [PATCH 5/7] head: fp32_lm_head flag to match vLLM bf16_last_layer_fp32
 precision

Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden
states and output_weights are cast to float32 before the lm_head linear,
producing FP32 logits. This matches vLLM's bf16_last_layer_fp32 quantization
(pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's
apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed at
the same numerical precision and the IS ratio starts near 1.0 at init.

The gradient flowing back through the linear is cast to the original input
dtype (bf16) before returning, keeping the transformer backward pass in its
native dtype.
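
A minimal standalone sketch of the intended numerics (hypothetical shapes,
not the fast_llm code path):

    import torch

    hidden = torch.randn(8, 256, dtype=torch.bfloat16, requires_grad=True)
    weight = torch.randn(1024, 256, dtype=torch.bfloat16)  # (vocab, hidden)

    # fp32_lm_head=True: upcast both operands → FP32 logits, the precision at
    # which the actor computed old_logprobs.
    logits = hidden.to(torch.float32) @ weight.to(torch.float32).t()
    assert logits.dtype == torch.float32

    loss = torch.log_softmax(logits, -1)[:, 0].sum()
    loss.backward()
    # autograd casts the grad back through .to(), so the grad handed to the
    # transformer stays in the original dtype (bf16):
    assert hidden.grad.dtype == torch.bfloat16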
---
 fast_llm/layers/language_model/config.py |  7 +++++++
 fast_llm/layers/language_model/head.py   | 20 ++++++++++++++++----
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py
index 1de722cae..0d48e92cb 100644
--- a/fast_llm/layers/language_model/config.py
+++ b/fast_llm/layers/language_model/config.py
@@ -120,6 +120,13 @@ class LanguageModelHeadConfig(BlockConfig):
         hint=FieldHint.feature,
         valid=check_field(Assert.geq, 0),
     )
+    fp32_lm_head: bool = Field(
+        default=False,
+        desc="Upcast input and weight to float32 before the lm_head linear. "
+        "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs "
+        "are computed at the same numerical precision, keeping the IS ratio near 1 at init.",
+        hint=FieldHint.feature,
+    )
     prediction_heads: int = Field(
         default=1,
         desc="Prediction heads.",

diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 95be18035..87da4fbbd 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -242,9 +242,16 @@ def _logits_loss_forward_backward_partial(
         split_index: int = 0,
         return_logits: bool = False,
     ) -> tuple[torch.Tensor | None, torch.Tensor | None]:
+        if self._config.fp32_lm_head:
+            input_dtype = input_.dtype
+            input_ = input_.to(torch.float32)
+            weight = self.output_weights.to(torch.float32)
+        else:
+            weight = self.output_weights
+
         logits, context = output_parallel_linear_forward(
             input_=input_,
-            weight=self.output_weights,
+            weight=weight,
             bias=None,
             group=self._parallel_dim.group if self._vocab_parallel else None,
             sequence_parallel=self._sequence_parallel and self._vocab_parallel,
@@ -273,9 +280,14 @@ def _logits_loss_forward_backward_partial(
         if loss_value is not None:
             losses_.append(loss_value.detach())
 
-        return sum(losses_) if losses_ else None, (
-            output_parallel_linear_backward(grad, context) if self.training else None
-        )
+        if not self.training or grad is None:
+            return sum(losses_) if losses_ else None, None
+
+        input_grad = output_parallel_linear_backward(grad, context)
+        if self._config.fp32_lm_head:
+            input_grad = input_grad.to(input_dtype)
+
+        return sum(losses_) if losses_ else None, input_grad
 
     def get_loss_definitions(self) -> list[LossDef]:
         return [

From 0f90f20b1bd8793ebfabb7c0aeaf37998d9e1177 Mon Sep 17 00:00:00 2001
From: bigximik
Date: Mon, 4 May 2026 07:42:52 +0000
Subject: [PATCH 6/7] head: fix fp32_lm_head gradient flow via detach + manual
 weight grad accumulation

Detaching the FP32 weight copy (requires_grad=False) prevents
output_parallel_linear_backward from trying to write to a non-existent
grad_buffer on the copy. Weight grad is then computed explicitly from the
FP32 matmul and accumulated into the original BF16 param's grad_buffer via
accumulate_gradient, restoring the correct FSDP gradient contract.
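
Sketch of the manual weight-grad path (standalone toy; names are simplified
stand-ins for the patched code):

    import torch

    x = torch.randn(8, 256)                      # saved_input: (tokens, hidden), fp32
    w_bf16 = torch.randn(1024, 256, dtype=torch.bfloat16, requires_grad=True)
    w_fp32 = w_bf16.detach().to(torch.float32)   # detached copy: autograd skips its grad

    logits = x @ w_fp32.t()
    grad_logits = torch.randn_like(logits)       # incoming grad w.r.t. logits

    # For logits = x @ w.t(), the weight grad is grad_logits.t() @ x
    # (the same matmul the patch uses).
    grad_w = grad_logits.t().mm(x)               # (vocab, hidden), matches w's shape
    w_bf16.grad = grad_w.to(w_bf16.dtype)        # stand-in for accumulate_gradient(...)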
---
 fast_llm/layers/language_model/head.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py
index 87da4fbbd..31addb34c 100644
--- a/fast_llm/layers/language_model/head.py
+++ b/fast_llm/layers/language_model/head.py
@@ -22,7 +22,7 @@
 )
 from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
 from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
-from fast_llm.tensor import TensorMeta
+from fast_llm.tensor import TensorMeta, accumulate_gradient
 from fast_llm.utils import Assert, safe_merge_dicts
 
 logger = logging.getLogger(__name__)
@@ -245,7 +245,8 @@ def _logits_loss_forward_backward_partial(
         if self._config.fp32_lm_head:
             input_dtype = input_.dtype
             input_ = input_.to(torch.float32)
-            weight = self.output_weights.to(torch.float32)
+            # detach → requires_grad=False → output_parallel_linear_backward skips weight grad
+            weight = self.output_weights.detach().to(torch.float32)
         else:
             weight = self.output_weights
 
@@ -285,6 +286,15 @@ def _logits_loss_forward_backward_partial(
 
         input_grad = output_parallel_linear_backward(grad, context)
         if self._config.fp32_lm_head:
+            # Weight grad was skipped because weight.requires_grad=False; accumulate manually.
+            # context: (input_, weight, bias, group, sequence_parallel, ...)
+            saved_input = context[0]
+            if context[4]:  # sequence_parallel
+                from fast_llm.core.ops import gather_op
+
+                saved_input = gather_op(saved_input, context[3], dim=0)
+            grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2))
+            accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype))
             input_grad = input_grad.to(input_dtype)
 
         return sum(losses_) if losses_ else None, input_grad

From 557a3c4c1a4aea08049d467510693745834f10dd Mon Sep 17 00:00:00 2001
From: bigximik
Date: Tue, 5 May 2026 13:45:31 +0000
Subject: [PATCH 7/7] grpo: decouple loss/gradient divisors and fix SDP loss
 double-counting
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10
reward points slower than DS GSPO at the same step count.

The lm_head_loss metric was also off — 1024× smaller than DS's rl/loss in
the previous divisor=num_documents² formulation, then 2× too large from SDP
doubling.

Root cause analysis
-------------------
DeepSpeed has TWO 1/batch_size factors with different sources:

1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
   (pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
   value is the raw policy_loss_total, divided once by batch_size.

2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
   from `scale_wrt_gas=True` in engine.backward()
   (deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
   reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).

For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.

Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a
~batch_size = 1024× mismatch.

Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted — each
SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.

Fixes
-----
1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
   `fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
   defaulting to `divisor` (existing behavior). Allows the gradient to use a
   different divisor than the loss.

2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
   is True, set:
       loss divisor     = num_documents_in_batch  (matches DS rl/loss)
       gradient divisor = num_documents_in_batch² (matches DS grad_norm)
   This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
   because batch_size is invariant under all of these.

3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
   (`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
   that LossDef.reduce's SUM over data_group introduces. The gradient is
   unaffected — different SDP ranks naturally contribute gradient from
   different LOCAL token positions, no double-counting at any layer.

Verification
------------
Tested on a 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:

    Before fix            | After fix              | DS GSPO reference
    ----------------------|------------------------|----------------------
    step 1 grad_norm=141  | step 1 grad_norm=0.135 | step 1 grad_norm=0.145
    step 1 lm_head_loss   | step 1 lm_head_loss    | step 1 rl/loss
      = -13.7             |   ~ -1.7 (sign varies  |   = -1.7
                          |   per data sample)     |
    clip_coeff=0.002      | clip_coeff=1.000       | no clipping at step 1
    newlp at step 50      | newlp at step 50       | newlp at step 50
      trapped at -0.17    |   = -0.103             |   = -0.105

newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.

Files changed
-------------
- fast_llm/layers/language_model/loss/grpo.py:
  - LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor
    based on normalize_by_documents flag, with detailed comments referencing
    the corresponding lines in DeepSpeed and PipelineRL.
  - fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss
    by sdp_size when sdp_group is active.
  - fused_grpo_loss_forward_backward: add grad_divisor parameter.
- fast_llm/functional/triton/grpo_loss.py:
  - triton_grpo_loss_forward_backward: add grad_divisor parameter.
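
Worked example (illustrative arithmetic only; the numbers are hypothetical,
chosen to match the magnitudes above):

    batch_size = 1024                          # documents per optimizer step
    raw_loss_sum = -1.7 * batch_size           # pretend summed policy loss

    reported_loss = raw_loss_sum / batch_size  # divisor = num_documents → DS rl/loss parity
    grad_scale_old = 1.0 / batch_size          # previous single-divisor gradient scaling
    grad_scale_new = 1.0 / batch_size**2       # grad_divisor = num_documents²
    print(reported_loss)                       # -1.7
    print(grad_scale_old / grad_scale_new)     # 1024.0 → the observed grad_norm mismatch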
---
 fast_llm/functional/triton/grpo_loss.py     |  5 ++-
 fast_llm/layers/language_model/loss/grpo.py | 50 ++++++++++++++++++---
 2 files changed, 47 insertions(+), 8 deletions(-)

diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py
index 39d832ccd..709bbc73c 100644
--- a/fast_llm/functional/triton/grpo_loss.py
+++ b/fast_llm/functional/triton/grpo_loss.py
@@ -137,6 +137,7 @@ def triton_grpo_loss_forward_backward(
     logits_scale_factor: float = 1.0,
     num_labels_in_seq: torch.Tensor | None = None,
     divisor: float | None = None,
+    grad_divisor: float | None = None,  # Optional separate divisor for the gradient (defaults to divisor)
     block_size: int | None = None,
     num_warps: int | None = None,
 ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
@@ -148,6 +149,8 @@ def triton_grpo_loss_forward_backward(
     n_cols = logits.size(-1)
     if divisor is None:
         divisor = n_rows
+    if grad_divisor is None:
+        grad_divisor = divisor
     if block_size is None:
         block_size = min(triton.next_power_of_2(n_cols), 32768)
     if num_warps is None:
@@ -171,7 +174,7 @@ def triton_grpo_loss_forward_backward(
         grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits
         backward_kwargs = {
             "grad_logits_ptr": grad_logits,
-            "grad_losses": grad_output / divisor,
+            "grad_losses": grad_output / grad_divisor,
             "grad_logits_stride_0": grad_logits.stride(-2),
             "accumulate": accumulate,
         }

diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py
index 8472580f8..f36d4474e 100644
--- a/fast_llm/layers/language_model/loss/grpo.py
+++ b/fast_llm/layers/language_model/loss/grpo.py
@@ -21,11 +21,29 @@ def _forward_backward(
         split_index: int = 0,
         grad_logits: torch.Tensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor | None]:
-        divisor = (
-            kwargs[LanguageModelKwargs.num_documents_in_batch]
-            if self._config.normalize_by_documents
-            else self._get_label_count(kwargs)
-        )
+        if self._config.normalize_by_documents:
+            # Match DeepSpeed exactly. DS has TWO 1/batch_size factors with different sources:
+            # - Loss reported uses /batch_size (via tokens_weights = 1/batch_size, see
+            #   pipelinerl/finetune/rl/__init__.py:246).
+            # - Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes from
+            #   `scale_wrt_gas=True` in engine.backward() (deepspeed/runtime/engine.py:1995-1996)
+            #   and `tensor.div_(world_sz)` in reduce_scatter_coalesced
+            #   (deepspeed/runtime/comm/coalesced_collectives.py:124).
+            # For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size,
+            # so the gradient buffer effectively has factor 1/batch_size² while the loss metric has 1/batch_size.
+            # Fast-LLM cancels DS's /(gas × world_size) factor via `grad_output = data_parallel × grad_scale`
+            # (runner.py:318) interacting with FSDP's RS-AVG over data_parallel ranks (fsdp.py:396).
+            # So we need to apply the second 1/batch_size factor explicitly only to the gradient,
+            # keeping the loss metric matched to DS:
+            #   loss divisor     = num_documents  (matches DS rl/loss)
+            #   gradient divisor = num_documents² (matches DS grad_norm)
+            # Both are independent of TP/PP/SDP/DP parallelism and microbatching schedule.
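+            # Illustrative, with hypothetical numbers: for num_documents = 1024, the reported
+            # loss is raw_sum/1024 (DS rl/loss parity) while the gradient is scaled by 1/1024²
+            # (DS grad_norm parity); a single shared divisor would leave the gradient 1024×
+            # too large, which is exactly the mismatch described above.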
+            num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch]
+            divisor = num_documents
+            grad_divisor = num_documents * num_documents
+        else:
+            divisor = self._get_label_count(kwargs)
+            grad_divisor = None  # use divisor (default behavior)
         if self._config.policy_loss == "gspo":
             loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward(
                 logits,
@@ -45,6 +63,7 @@
                     else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index)
                 ),
                 divisor=divisor,
+                grad_divisor=grad_divisor,
                 sdp_group=self._sdp_dim.group if self._sdp_active else None,
             )
         else:
@@ -71,6 +90,7 @@
                     else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index)
                 ),
                 divisor=divisor,
+                grad_divisor=grad_divisor,
             )
 
         if new_logprobs_mean is not None:
@@ -198,10 +218,13 @@ def fused_grpo_loss_forward_backward(
         torch.Tensor | None
     ) = None,  # (*batch,) — response-span length broadcast per token, 0 for non-response
     divisor: float | None = None,
+    grad_divisor: float | None = None,  # Optional separate divisor for the gradient (defaults to divisor)
 ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
     if divisor is None:
         divisor = logits.shape[:-1].numel()
-    grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor
+    if grad_divisor is None:
+        grad_divisor = divisor
+    grad_output = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor
 
     loss_mask = target >= 0
     logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group)
@@ -272,6 +295,7 @@ def fused_gspo_loss_forward_backward(
     logits_scale_factor: float = 1.0,
     num_labels_in_seq: torch.Tensor | None = None,  # for new_logprobs_mean metric
     divisor: float | None = None,
+    grad_divisor: float | None = None,  # Optional separate divisor for the gradient (defaults to divisor)
     sdp_group: torch.distributed.ProcessGroup | None = None,  # SDP group for cross-rank segment aggregation
 ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
     """GSPO loss: sequence-level geometric-mean IS ratio clipping.
@@ -282,10 +306,15 @@
 
     SDP correctness: scatter_add sums are all-reduced across sdp_group before computing
     R_s and A_s, ensuring correct segment-level ratios when tokens are split across ranks.
+
+    The optional `grad_divisor` allows the gradient to use a different divisor than the loss
+    (e.g., to match DeepSpeed's metric where loss has /batch_size and gradient has /batch_size²).
     """
     if divisor is None:
         divisor = float(logits.shape[0]) if logits.shape[0] > 0 else 1.0
-    grad_output_scaled = None if grad_output is None else grad_output / divisor * logits_scale_factor
+    if grad_divisor is None:
+        grad_divisor = divisor
+    grad_output_scaled = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor
 
     loss_mask = target >= 0
     mask_float = loss_mask.float()
@@ -340,6 +369,13 @@
     surr2 = R.clamp(1.0 - epsilon_low, 1.0 + epsilon_high) * A
     loss_per_seg = -torch.minimum(surr1, surr2) * tok_sum * valid.float()
     loss = loss_per_seg.sum() / divisor
+    # SDP correction: after SDP allreduce of lrn/adv/tok, both SDP ranks compute the IDENTICAL
+    # per-segment loss, so when LossDef.reduce sums across data_group (which includes SDP), the
+    # metric is double-counted by sdp_size. Divide here so each SDP rank reports loss/sdp_size,
+    # making the SUM-reduction match a non-SDP run. Gradient is unaffected (each SDP rank
+    # contributes gradient from its own LOCAL tokens, no double-counting in the gradient buffer).
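+    # (Illustrative, hypothetical numbers: with sdp_size=2, both ranks compute segment loss L;
+    #  the SUM over data_group would report 2L, so each rank reports L/2 to restore L.)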
+    if sdp_group is not None:
+        loss = loss / torch.distributed.get_world_size(sdp_group)
 
     # Step 7: Gradient — broadcast segment-level factors to token level
     if grad_output_scaled is not None and n_segs > 0: