gspo: GSPO loss + DeepSpeed parity fixes (loss/grad divisors, SDP, fp32_lm_head, docs_per_step, temperature)#502
Open
bigximik wants to merge 9 commits into grpo-metrics from
Conversation
Implements GSPO (geometric-mean sequence-level policy-gradient loss) as an alternative to the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo". Key changes:
- data pipeline: expose per-token document_index when return_document_index=True
- LanguageModelKwargs.document_index: new kwarg constant
- LanguageModelLoss: store SDP dim for cross-rank segment aggregation
- grpo.py: fused_gspo_loss_forward_backward with all_reduce(SUM) across SDP ranks before computing segment-level R_s and A_s; gradient derivation exploits tok_count cancellation so every token in a segment gets the same gradient factor R_s * clip_indicator_s
- tests/layers/test_gspo_loss.py: 8 unit tests (single-segment, packed, ratio-1 equivalence, clipping, masking, SDP mock, gradient finite-diff, per-token metrics unchanged)
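For orientation, here is a pure-PyTorch, unfused sketch of the segment-level objective the fused kernel above implements (no SDP all_reduce, no divisor handling; the helper name and signature are illustrative, not the real fused_gspo_loss_forward_backward):

```python
# Unfused reference of the GSPO objective, assuming packed sequences identified by a
# per-token document_index. The real fused kernel derives a closed-form per-token
# gradient factor (R_s * clip_indicator_s) instead of relying on autograd as here.
import torch

def gspo_loss_reference(new_logprobs, old_logprobs, advantages, document_index,
                        loss_mask, eps_low=0.2, eps_high=0.2):
    """All inputs are flat [tokens] tensors; loss_mask is bool; advantages are
    constant within a segment (one advantage per rollout)."""
    loss = new_logprobs.new_zeros(())
    for seg in document_index.unique():
        sel = (document_index == seg) & loss_mask
        if not sel.any():
            continue
        # Geometric-mean importance ratio over the segment: exp(mean log-ratio).
        ratio = (new_logprobs[sel] - old_logprobs[sel]).mean().exp()
        clipped = ratio.clamp(1.0 - eps_low, 1.0 + eps_high)
        adv = advantages[sel][0]
        # Sequence-level pessimistic PPO-style objective.
        loss = loss - torch.minimum(ratio * adv, clipped * adv)
    return loss
```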
Adds ScheduleConfig.rollouts_per_step (default 0). When >0, TrainerConfig._from_dict
computes depth_first_micro_batches = rollouts_per_step // (batch_data_parallel ×
breadth_first_micro_batches) before sub-configs are created (and frozen).
Matches DeepSpeed gradient_accumulation_passes semantics for RL: with train_batch_size=1
each microbatch holds one rollout, so setting rollouts_per_step=1024 with data_parallel=8
gives depth_first_micro_batches=128 → exactly 1024 rollouts per optimizer step globally.
YAML usage:
    schedule:
      rollouts_per_step: 1024   # replaces manual depth_first_micro_batches
    model:
      distributed:
        data_parallel: 8        # used for the division
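A quick arithmetic check of that division, with the numbers from the example above (values illustrative):

```python
# Quick check of the division described above (illustrative values).
rollouts_per_step = 1024
data_parallel = 8                 # batch_data_parallel from the YAML above
breadth_first_micro_batches = 1
depth_first_micro_batches = rollouts_per_step // (data_parallel * breadth_first_micro_batches)
assert depth_first_micro_batches == 128
# 8 data-parallel ranks x 128 depth-first microbatches x 1 rollout each = 1024 rollouts/step.
```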
- Rename rollouts_per_step → docs_per_step in ScheduleConfig; depth_first
is now determined at runtime rather than statically in _from_dict
- Add Schedule._depth_first_override and _eff_{depth_first,sequential,num_inputs}
properties so per-step schedules share the same config object as the runner
- Add Trainer._prefetch_to_doc_target (sketched below): fetches microbatches one at a time,
all-reduces the doc count per microbatch, stops when the global total ≥ docs_per_step,
then resets num_documents_in_batch to the step total on all inputs
- Add Trainer._get_or_build_schedule: builds and caches per-N Schedule with
_depth_first_override=N//breadth_first_micro_batches
- Add normalize_by_documents flag to LanguageModelGRPOLossConfig; when True
both GRPO and GSPO paths divide by num_documents_in_batch instead of
num_labels_in_batch (matches DeepSpeed's per-rollout normalization)
- Add tests/layers/test_docs_per_step.py: 13 unit tests covering divisor
scaling, normalize_by_documents layer routing, Schedule._eff_* properties,
and _prefetch_to_doc_target accumulation logic
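A rough sketch of the accumulation loop described in the _prefetch_to_doc_target bullet above (the loop shape follows the description; the function name, dict-style batch access, and process-group handling are hypothetical stand-ins for Fast-LLM's own data iterator and distributed config):

```python
# Rough sketch of the doc-target prefetch loop (names and batch access hypothetical).
import torch
import torch.distributed as dist

def prefetch_to_doc_target(data_iter, docs_per_step, group=None):
    batches, global_total = [], 0
    while global_total < docs_per_step:
        batch = next(data_iter)                       # fetch one microbatch at a time
        docs = torch.tensor([batch["num_documents_in_batch"]], dtype=torch.int64)
        if group is not None:
            dist.all_reduce(docs, op=dist.ReduceOp.SUM, group=group)  # global doc count
        global_total += int(docs.item())
        batches.append(batch)
    for batch in batches:                             # step total becomes the divisor
        batch["num_documents_in_batch"] = global_total
    return batches, global_total
```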
Add temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probs are computed at the same temperature as the stored old log-probs, so the IS ratio starts near 1.0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted for logits_scale_factor at all three callsites in _forward_backward (GRPO path, GSPO path, _register_pg_metrics). Default temperature=1.0 preserves existing behaviour exactly.
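A minimal sketch of the effect, assuming the new log-probs come from a log-softmax over scaled logits (tensor shapes illustrative):

```python
# Effect of the temperature field on the new log-probs (shapes illustrative).
import torch

logits = torch.randn(4, 32000)                     # [tokens, vocab]
logits_scale_factor, temperature = 1.0, 0.7        # 0.7 = actor's sampling temperature
effective_scale = logits_scale_factor / temperature
new_logprobs = torch.log_softmax(logits * effective_scale, dim=-1)
# These log-probs are now on the same temperature scale as the stored old log-probs,
# so the importance-sampling ratio starts near 1.0 instead of ~1.08.
```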
Add fp32_lm_head to LanguageModelHeadConfig. When enabled, input hidden states and output_weights are cast to float32 before the lm_head linear, producing FP32 logits. This matches vLLM's bf16_last_layer_fp32 quantization (pipelinerl/vllm_quantization.py) and the DeepSpeed trainer's apply_fp32_lm_head() patch, so new_logprobs and old_logprobs are computed at the same numerical precision and the IS ratio starts near 1.0 at init. The gradient flowing back through the linear is cast to the original input dtype (bf16) before returning, keeping the transformer backward pass in its native dtype.
…accumulation
Detaching the FP32 weight copy (requires_grad=False) prevents output_parallel_linear_backward from trying to write to a non-existent grad_buffer on the copy. Weight grad is then computed explicitly from the FP32 matmul and accumulated into the original BF16 param's grad_buffer via accumulate_gradient, restoring the correct FSDP gradient contract.
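A single-GPU sketch of the forward/backward behaviour these two commits describe, assuming a plain matmul head; the real code goes through output_parallel_linear_* and accumulate_gradient, which are not reproduced here:

```python
# Single-GPU sketch of the fp32 LM head forward/backward (plain matmul stand-in).
import torch

def fp32_lm_head_forward(hidden_bf16, weight_bf16):
    hidden_fp32 = hidden_bf16.float()
    weight_fp32 = weight_bf16.detach().float()      # detached copy, requires_grad=False
    logits_fp32 = hidden_fp32 @ weight_fp32.t()     # FP32 logits
    return logits_fp32, (hidden_fp32, weight_fp32, hidden_bf16.dtype)

def fp32_lm_head_backward(grad_logits, ctx):
    hidden_fp32, weight_fp32, input_dtype = ctx
    grad_hidden = (grad_logits @ weight_fp32).to(input_dtype)    # backbone stays in bf16
    grad_weight = grad_logits.t() @ hidden_fp32                  # explicit FP32 weight grad
    # In the PR, grad_weight is accumulated into the BF16 param's grad_buffer via
    # accumulate_gradient, restoring the FSDP gradient contract.
    return grad_hidden, grad_weight
```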
When normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024×
larger than DeepSpeed's for the equivalent loss, causing the default
gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10
reward points slower than DS GSPO at the same step count. The lm_head_loss
metric was also off — 1024× smaller than DS's rl/loss in the previous
divisor=num_documents² formulation, then 2× too large from SDP doubling.
Root cause analysis
-------------------
DeepSpeed has TWO 1/batch_size factors with different sources:
1. Loss reported (rl/loss) uses /batch_size via tokens_weights = 1/batch_size
(pipelinerl/finetune/rl/__init__.py:246). The reported `rl/loss = -1.7`
value is the raw policy_loss_total, divided once by batch_size.
2. Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes
from `scale_wrt_gas=True` in engine.backward()
(deepspeed/runtime/engine.py:1995-1996) and `tensor.div_(world_sz)` in
reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).
For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size
= batch_size, so DS's effective gradient buffer factor is 1/batch_size² while
the loss metric factor is 1/batch_size. Loss and gradient have asymmetric
scaling.
Fast-LLM's existing implementation used a single `divisor` for both loss and
gradient. Worse, the data_parallel × grad_scale factor in grad_output
(runner.py:318) cancels with FSDP's RS-AVG /world_size, structurally removing
DS's /(gas × world_size) factor from the gradient. So fast-LLM's gradient
buffer ended up at 1/batch_size while DS's was at 1/batch_size² — a
~batch_size = 1024× mismatch.
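The mismatch as a worked example (illustrative numbers; world_size stands in for the data-parallel size):

```python
# Worked example of the loss/gradient scaling mismatch (illustrative numbers).
batch_size = 1024                        # rollouts per optimizer step
world_size = 8                           # data-parallel ranks
gas = batch_size // world_size           # grad-accumulation passes, 1 rollout/microbatch
ds_loss_factor = 1 / batch_size                        # tokens_weights = 1/batch_size
ds_grad_factor = ds_loss_factor / (gas * world_size)   # scale_wrt_gas + div_(world_size)
fastllm_grad_factor = 1 / batch_size                   # DP x grad_scale cancels RS-AVG
assert ds_grad_factor == 1 / batch_size**2
assert fastllm_grad_factor / ds_grad_factor == batch_size   # the ~1024x grad_norm gap
```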
Additionally, GSPO's SDP allreduce of lrn_sum/adv_sum/tok_sum makes both SDP
ranks compute IDENTICAL per-segment loss values. When LossDef.reduce sums
over the data_group (which includes SDP ranks), the loss metric is
double-counted by sdp_size. The gradient buffer is NOT double-counted —
each SDP rank contributes gradient from its own LOCAL tokens, with different
contributions for different tokens of the same segment.
Fixes
-----
1. Add a `grad_divisor` parameter to `fused_gspo_loss_forward_backward`,
`fused_grpo_loss_forward_backward`, and `triton_grpo_loss_forward_backward`,
defaulting to `divisor` (existing behavior). Allows the gradient to use a
different divisor than the loss.
2. In `LanguageModelGRPOLoss._forward_backward`, when normalize_by_documents
   is True, set:
     loss divisor     = num_documents_in_batch   (matches DS rl/loss)
     gradient divisor = num_documents_in_batch²  (matches DS grad_norm)
   This is independent of TP/PP/SDP/DP parallelism and microbatching schedule
   because batch_size is invariant under all of these. (See the sketch after
   this list.)
3. In the GSPO path, divide the loss by sdp_size when sdp_group is active
(`fused_gspo_loss_forward_backward`). This pre-cancels the SDP doubling
that LossDef.reduce's SUM over data_group introduces. The gradient is
unaffected — different SDP ranks naturally contribute gradient from
different LOCAL token positions, no double-counting at any layer.
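A condensed view of how fixes 1-3 combine when picking divisors (simplified helper, not the actual _forward_backward code; the function name and arguments are illustrative):

```python
# Condensed view of the divisor selection implied by fixes 1-3 (simplified).
def pick_divisors(normalize_by_documents, num_documents_in_batch, num_labels_in_batch,
                  sdp_size=1, is_gspo=False):
    if normalize_by_documents:
        loss_divisor = num_documents_in_batch              # matches DS rl/loss
        grad_divisor = num_documents_in_batch ** 2         # matches DS grad_norm
    else:
        loss_divisor = grad_divisor = num_labels_in_batch  # existing per-token behaviour
    if is_gspo and sdp_size > 1:
        # GSPO's SDP all_reduce makes every SDP rank report identical per-segment loss
        # values, so pre-divide the reported loss (not the gradient) by sdp_size.
        loss_divisor *= sdp_size
    return loss_divisor, grad_divisor
```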
Verification
------------
Tested on 7B math run with 4 nodes, GSPO, gradient_norm_clipping=0.3:
Metric (step)         | Before fix       | After fix                        | DS GSPO reference
----------------------|------------------|----------------------------------|------------------
grad_norm (step 1)    | 141              | 0.135                            | 0.145
lm_head_loss (step 1) | -13.7            | ~ -1.7 (sign varies per sample)  | rl/loss = -1.7
clip_coeff (step 1)   | 0.002            | 1.000                            | no clipping
newlp (step 50)       | trapped at -0.17 | -0.103                           | -0.105
newlp trajectory tracks DS step-by-step: step 1 within 3%, step 50 within 2%.
Both systems show grad_norm spikes at the same training phase (steps 14-20)
during warmup ramp-up — DS step 16 grad_norm=6.365 vs Fast-LLM 6.093.
Files changed
-------------
- fast_llm/layers/language_model/loss/grpo.py:
  - LanguageModelGRPOLoss._forward_backward: split divisor and grad_divisor based on
    the normalize_by_documents flag, with detailed comments referencing the
    corresponding lines in DeepSpeed and PipelineRL.
  - fused_gspo_loss_forward_backward: add grad_divisor parameter; divide loss by
    sdp_size when sdp_group is active.
  - fused_grpo_loss_forward_backward: add grad_divisor parameter.
- fast_llm/functional/triton/grpo_loss.py:
  - triton_grpo_loss_forward_backward: add grad_divisor parameter.
# Conflicts:
#   fast_llm/layers/language_model/loss/config.py
#   fast_llm/layers/language_model/loss/grpo.py
Summary
This PR adds GSPO loss to fast-LLM along with a suite of supporting fixes that together achieve full metric and training-trajectory parity with DeepSpeed's GRPO/GSPO implementation. Targets the grpo-metrics branch. Six logical units:

1. GSPO loss (sequence-level IS-ratio clipping)
Implements GSPO as an alternative policy-gradient loss alongside the existing per-token GRPO clipping. Controlled via LanguageModelGRPOLossConfig.policy_loss = "gspo".
- fused_gspo_loss_forward_backward kernel: computes the per-segment geometric-mean log-ratio R_s, clips at [1−ε_low, 1+ε_high], and applies R_s × A_s as a uniform per-token gradient within each segment. An all_reduce(SUM) over sequence-data-parallel ranks aggregates (lrn_sum, adv_sum, tok_count) before clipping so the ratio is correct under sequence parallelism.
- document_index data field and LanguageModelKwargs.document_index kwarg constant to route per-token segment membership through the data pipeline.
- tests/layers/test_gspo_loss.py (single-segment, packed sequences, ratio=1 equivalence, clipping, masking, SDP mock, gradient finite-diff, independence from per-token metrics).

2. Dynamic docs_per_step accumulation

Replaces static depth_first_micro_batches with a runtime document-count target — matching DeepSpeed's gradient_accumulation_passes semantics for RL (where each microbatch holds one rollout).
- ScheduleConfig.docs_per_step: when >0, Trainer._prefetch_to_doc_target fetches microbatches one at a time, all-reduces the per-microbatch document count, and stops once the global total ≥ docs_per_step. The final step total is broadcast to all inputs so the normalisation denominator is consistent.
- Trainer._get_or_build_schedule builds and caches a per-N Schedule with _depth_first_override = N // breadth_first_micro_batches, so the existing schedule machinery is reused without changes to the runner.
- Schedule._eff_{depth_first,sequential,num_inputs} properties expose the effective values for a given override.
- tests/layers/test_docs_per_step.py.

3. normalize_by_documents

Adds a normalize_by_documents flag to LanguageModelGRPOLossConfig. When True, both the GRPO and GSPO paths divide the loss by num_documents_in_batch (the step-level rollout count) rather than the token count. Matches DeepSpeed's normalization where tokens_weights = 1 / batch_size.

4. Temperature scaling for IS ratio parity
Adds a temperature field to LanguageModelGRPOLossConfig. When set to match the actor's sampling temperature (e.g. 0.7), new log-probabilities are computed at the same temperature as the stored old log-probabilities from vLLM, so the IS ratio starts near 1.0 at step 0 instead of ~1.08. Implementation: _effective_logits_scale = logits_scale_factor / temperature, substituted at all three call-sites in _forward_backward. Default temperature=1.0 preserves existing behaviour exactly.

5. fp32_lm_head precision fix (matches vLLM's bf16_last_layer_fp32)

Adds a fp32_lm_head flag (default False) on LanguageModelHeadConfig. When True, the LM head's logits computation upcasts both input and weight to FP32 before the linear projection, matching vLLM's bf16_last_layer_fp32 quantization. This ensures the trainer computes log-probabilities at the same numerical precision as the actor's sampling, so new_logprobs ≈ old_logprobs at step 0 (IS ratio at training start ≈ 1.0, not artificially inflated by precision mismatch).
- d8cb9ef5: introduces the flag, upcasts input/weight, casts back to BF16 before downstream consumption.
- 0f90f20b: fixes the gradient flow when fp32_lm_head=True. The detached FP32 weight copy has requires_grad=False, which makes output_parallel_linear_backward skip writing to the original weight's grad_buffer. We restore the FSDP gradient contract by computing grad_weight = grad.t() @ saved_input explicitly and accumulating into the BF16 param's grad_buffer via accumulate_gradient.

6. Decoupled loss/gradient divisors and SDP loss double-counting fix
Even with normalize_by_documents=true, fast-LLM's reported grad_norm was ~1024× larger than DeepSpeed's, causing the default gradient_norm_clipping=0.3 to over-clip by ~500× and making training ~10 reward points slower than DS GSPO at the same step count. Two issues, fixed in commit 557a3c4c:

Asymmetric loss/gradient scaling in DS:
- The reported loss (rl/loss) is divided by batch_size once (via tokens_weights = 1/batch_size, pipelinerl/finetune/rl/__init__.py:246).
- The gradient buffer gets an additional /(gas × world_size) factor from scale_wrt_gas=True in engine.backward() (deepspeed/runtime/engine.py:1995-1996) and tensor.div_(world_sz) in reduce_scatter_coalesced (deepspeed/runtime/comm/coalesced_collectives.py:124).
- With samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size, so the gradient buffer effectively has 1/batch_size² while the loss metric has 1/batch_size.

Fast-LLM cancels DS's /(gas × world_size) factor structurally via grad_output = data_parallel × grad_scale (runner.py:318) interacting with FSDP's RS-AVG over data_parallel ranks (fsdp.py:396). So we need to apply the second 1/batch_size factor explicitly only to the gradient — keeping the loss metric matched to DS.

Fix: add a grad_divisor parameter to fused_gspo_loss_forward_backward, fused_grpo_loss_forward_backward, and triton_grpo_loss_forward_backward. When normalize_by_documents=true:
- loss divisor = num_documents_in_batch (matches DS rl/loss)
- gradient divisor = num_documents_in_batch² (matches DS grad_norm)

Independent of TP/PP/SDP/DP parallelism and microbatching schedule, because batch_size is invariant under all of them.

SDP loss double-counting:
After the SDP allreduce of lrn_sum/adv_sum/tok_sum in fused_gspo_loss_forward_backward, both SDP ranks compute IDENTICAL per-segment loss values. When LossDef.reduce SUMs across data_group (which includes SDP ranks), the loss metric is double-counted by sdp_size. The gradient is NOT double-counted — each SDP rank contributes gradient from its own LOCAL tokens, with different contributions for different tokens of the same segment.

Fix: divide loss by sdp_size when sdp_group is active. Gradient unaffected.

Verification
End-to-end 7B math run on 4 nodes, GSPO, gradient_norm_clipping=0.3 (default), normalize_by_documents=true, fp32_lm_head=true, temperature=0.7. Compared metrics: grad_norm, lm_head_loss, clip_coeff, newlp (per-step values in the Verification table of the divisor-fix commit above). The newlp trajectory tracks DS step-by-step. Both systems show the same gradient-spike pattern during warmup ramp-up at steps 14-20 (DS step 16 grad_norm=6.365, fast-LLM step 15=9.005). Match within data variance.
Test plan
- pytest tests/layers/test_gspo_loss.py — GSPO unit tests pass
- pytest tests/layers/test_docs_per_step.py — docs_per_step unit tests pass
- pytest tests/layers/test_lm_losses.py — existing GRPO loss + per-token metrics tests unaffected (the metrics tests previously in test_grpo_metrics.py moved into this file on the base branch)
- End-to-end 7B math run (docs_per_step=1024, temperature=0.7, normalize_by_documents=true, fp32_lm_head=true, default gradient_norm_clipping=0.3) — grad_norm matches DS at step 1, training trajectory matches DS step-by-step through step 50+ (ongoing run validates through step ~410).