diff --git a/fast_llm/data/document/config.py b/fast_llm/data/document/config.py index 352311b51..ad6a7305f 100644 --- a/fast_llm/data/document/config.py +++ b/fast_llm/data/document/config.py @@ -79,6 +79,7 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig): use_preference_spans: bool = Field(default=False) use_grpo_data: bool = Field(default=False) return_label_counts: bool = Field(default=False) + return_document_index: bool = Field(default=False) def _validate(self) -> None: super()._validate() diff --git a/fast_llm/data/document/language_model.py b/fast_llm/data/document/language_model.py index 7821b81c5..8dab70efb 100644 --- a/fast_llm/data/document/language_model.py +++ b/fast_llm/data/document/language_model.py @@ -35,6 +35,7 @@ class LanguageModelTargetInput(ModelInput): advantages: torch.Tensor | None = None old_log_probabilities: torch.Tensor | None = None label_counts: torch.Tensor | None = None + document_index: torch.Tensor | None = None num_labels: int | None = None num_labels_in_batch: int | None = None @@ -84,6 +85,7 @@ def to_kwargs(self) -> dict[str, typing.Any]: LanguageModelKwargs.advantages: [target.advantages for target in self.targets], LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets], LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets], + LanguageModelKwargs.document_index: [target.document_index for target in self.targets], LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets], } if self.image_patches is not None: @@ -177,7 +179,11 @@ def _set_target_inputs( document_begin += length mask = labels >= 0 - label_counts = self._get_label_counts(mask) if config.return_label_counts else None + label_counts, document_index = ( + self._get_label_counts(mask) + if config.return_label_counts or config.return_document_index + else (None, None) + ) for input_index, model_input in enumerate(model_inputs): 
label_end = model_input.sequence_k_dim.size + prediction_distance @@ -188,6 +194,7 @@ def _set_target_inputs( tokens=labels[label_begin:label_end].clone(), mask=mask[label_begin:label_end] if config.return_prediction_mask else None, label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None, + document_index=document_index[label_begin:label_end] if config.return_document_index else None, # Set value for the first input only so `share_batch_data` generated the correct sum. # TODO: ====== Make optional? num_labels=( @@ -202,7 +209,7 @@ def _set_target_inputs( model_input.targets.append(target_input) - def _get_label_counts(self, mask: torch.Tensor): + def _get_label_counts(self, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: # Count the number of non-masked labels in each document through cumulative sums. mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)]) length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0) @@ -214,4 +221,4 @@ def _get_label_counts(self, mask: torch.Tensor): document_index = torch.searchsorted( length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right" ) - return labels_per_document[document_index] + return labels_per_document[document_index], document_index diff --git a/fast_llm/engine/schedule/config.py b/fast_llm/engine/schedule/config.py index 29720b90b..2920c1334 100644 --- a/fast_llm/engine/schedule/config.py +++ b/fast_llm/engine/schedule/config.py @@ -21,6 +21,16 @@ class ScheduleConfig(Config): hint=FieldHint.core, valid=check_field(Assert.gt, 0), ) + docs_per_step: int = Field( + default=0, + desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. " + "When >0, each training step dynamically accumulates microbatches until the globally all-reduced " + "document count reaches this value, then triggers the optimizer step. " + "depth_first_micro_batches is ignored when this is set. 
" + "0 = use depth_first_micro_batches as-is (fixed microbatch count per step).", + hint=FieldHint.feature, + valid=check_field(Assert.geq, 0), + ) breadth_first_micro_batches: int = Field( default=1, desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.", diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index b2e212946..128b95e8e 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -320,7 +320,8 @@ def _preprocess_data( if context.schedule.phase.is_training else None ) - model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)] + n_micro_batches = context.schedule._eff_sequential_micro_batches + model_inputs = [next(data_iterator) for _ in range(n_micro_batches)] model_inputs[0][0].share_batch_data( [model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed ) @@ -336,7 +337,7 @@ def _preprocess_data( extra_kwargs={ "grad_output": grad_output, "micro_batch": micro_batch, - "num_micro_batches": self._config.sequential_micro_batches, + "num_micro_batches": n_micro_batches, "micro_batch_splits": self._config.micro_batch_splits, }, ) diff --git a/fast_llm/engine/schedule/schedule.py b/fast_llm/engine/schedule/schedule.py index 6f7bf1d95..845b5df82 100644 --- a/fast_llm/engine/schedule/schedule.py +++ b/fast_llm/engine/schedule/schedule.py @@ -115,15 +115,17 @@ def __init__( batch_meta: list[ModelInput], distributed_config: DistributedConfig, phase: PhaseType, + _depth_first_override: int | None = None, ): super().__init__(config) + self._depth_first_override = _depth_first_override self._multi_stage = multi_stage self._distributed_config = distributed_config self._num_stages = len(self._multi_stage.stages) self._phase = phase self._is_training = self._phase.is_training - if self._config.num_inputs < self._distributed_config.pipeline_parallel: + if self._eff_num_inputs < 
self._distributed_config.pipeline_parallel: warnings.warn("Not enough input to achieve true pipeline parallelism.") # Setup the activation metas. @@ -155,9 +157,25 @@ def __init__( def phase(self) -> PhaseType: return self._phase + @property + def _eff_depth_first(self) -> int: + return ( + self._depth_first_override + if self._depth_first_override is not None + else self._config.depth_first_micro_batches + ) + + @property + def _eff_sequential_micro_batches(self) -> int: + return self._eff_depth_first * self._config.breadth_first_micro_batches + + @property + def _eff_num_inputs(self) -> int: + return self._eff_sequential_micro_batches * self._config.micro_batch_splits + @property def samples_per_batch(self) -> int: - return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel + return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]: return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank]) @@ -189,7 +207,7 @@ def _create_index(self) -> None: Assert.in_range( step.index, 0, - self._config.num_inputs, + self._eff_num_inputs, ) Assert.incl(step.type_, (StepType.forward, StepType.backward)) step.global_index = i @@ -205,7 +223,7 @@ def _create_index(self) -> None: Assert.custom(all, self._device_steps) # Consistency checks step_map = self._step_map.copy() - for data_index in range(self._config.num_inputs): + for data_index in range(self._eff_num_inputs): for type_ in (StepType.forward, StepType.backward): for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages): assert ( @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]: first_grad_stage += 1 else: first_grad_stage = self._num_stages - for depth_first_micro_batch in range(self._config.depth_first_micro_batches): + for depth_first_micro_batch in range(self._eff_depth_first): for stage 
in range(self._num_stages): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in range(self._config.micro_batch_splits): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]: for stage in reversed(range(first_grad_stage, self._num_stages)): for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches): for micro_batch_split in reversed(range(self._config.micro_batch_splits)): - micro_batch = ( - breadth_first_micro_batch * self._config.depth_first_micro_batches - + depth_first_micro_batch - ) + micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch steps.append( Step( stage=stage, diff --git a/fast_llm/engine/training/trainer.py b/fast_llm/engine/training/trainer.py index 1ed18c449..cc37e92c2 100644 --- a/fast_llm/engine/training/trainer.py +++ b/fast_llm/engine/training/trainer.py @@ -115,10 +115,13 @@ def setup(self, distributed: Distributed, run: Run) -> None: preprocessing_config = self._multi_stage.get_preprocessing_config( PhaseType.training, self._config.schedule.micro_batch_splits ) + self._preprocessing_config = preprocessing_config + self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size) + self._schedule_cache: dict[int, Schedule] = {} self._schedule = Schedule( config=self._config.schedule, multi_stage=self._multi_stage, - batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size), + batch_meta=self._single_mb_meta, distributed_config=self._config.model.distributed, phase=PhaseType.training, ) @@ -140,6 +143,41 @@ def setup(self, distributed: Distributed, run: Run) -> None: self._is_setup = True + def 
_get_or_build_schedule(self, n_microbatches: int) -> Schedule: + if n_microbatches not in self._schedule_cache: + bfmb = self._config.schedule.breadth_first_micro_batches + depth_first = n_microbatches // bfmb + self._schedule_cache[n_microbatches] = Schedule( + config=self._config.schedule, + multi_stage=self._multi_stage, + batch_meta=self._single_mb_meta, + distributed_config=self._config.model.distributed, + phase=PhaseType.training, + _depth_first_override=depth_first, + ) + return self._schedule_cache[n_microbatches] + + def _prefetch_to_doc_target(self, data_iterator) -> list: + target = self._config.schedule.docs_per_step + bfmb = self._config.schedule.breadth_first_micro_batches + buffer = [] + total_docs = 0 + while total_docs < target: + mb = next(data_iterator) + mb[0].share_batch_data(mb, self._distributed) + total_docs += mb[0].num_documents_in_batch + buffer.append(mb) + Assert.eq( + len(buffer) % bfmb, + 0, + msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}", + ) + # Reset num_documents_in_batch to the step total on all microbatches + for mb in buffer: + for mi in mb: + mi.num_documents_in_batch = total_docs + return buffer + @abc.abstractmethod def _get_data(self) -> Data: pass @@ -220,12 +258,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]: # TODO: Data loader hates getting all micro-batches at once. 
# (Also preprocessing adds overhead) - reduced_losses, update_successful, train_metrics = self._runner.run_step( - train_iterator, - self._schedule, - iteration=self._completed_steps, - return_metrics=is_logging, - ) + if self._config.schedule.docs_per_step > 0: + buffer = self._prefetch_to_doc_target(train_iterator) + step_schedule = self._get_or_build_schedule(len(buffer)) + reduced_losses, update_successful, train_metrics = self._runner.run_step( + iter(buffer), + step_schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) + else: + reduced_losses, update_successful, train_metrics = self._runner.run_step( + train_iterator, + self._schedule, + iteration=self._completed_steps, + return_metrics=is_logging, + ) # Advanced, skipped, and Nan iterations. if update_successful: diff --git a/fast_llm/functional/triton/grpo_loss.py b/fast_llm/functional/triton/grpo_loss.py index 39d832ccd..709bbc73c 100644 --- a/fast_llm/functional/triton/grpo_loss.py +++ b/fast_llm/functional/triton/grpo_loss.py @@ -137,6 +137,7 @@ def triton_grpo_loss_forward_backward( logits_scale_factor: float = 1.0, num_labels_in_seq: torch.Tensor | None = None, divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) block_size: int | None = None, num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: @@ -148,6 +149,8 @@ def triton_grpo_loss_forward_backward( n_cols = logits.size(-1) if divisor is None: divisor = n_rows + if grad_divisor is None: + grad_divisor = divisor if block_size is None: block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: @@ -171,7 +174,7 @@ def triton_grpo_loss_forward_backward( grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits backward_kwargs = { "grad_logits_ptr": grad_logits, - "grad_losses": grad_output / divisor, + "grad_losses": grad_output / grad_divisor, 
"grad_logits_stride_0": grad_logits.stride(-2), "accumulate": accumulate, } diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 4a8efdab6..0d48e92cb 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -25,6 +25,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs): sample_map = "sample_map" embedding_map = "embedding_map" num_documents_in_batch = "num_documents_in_batch" + document_index = "document_index" # TODO: These are generic phase = "phase" loss_mask = "loss_mask" @@ -119,6 +120,13 @@ class LanguageModelHeadConfig(BlockConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + fp32_lm_head: bool = Field( + default=False, + desc="Upcast input and weight to float32 before the lm_head linear. " + "Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs " + "are computed at the same numerical precision, keeping the IS ratio near 1 at init.", + hint=FieldHint.feature, + ) prediction_heads: int = Field( default=1, desc="Prediction heads.", diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 95be18035..31addb34c 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -22,7 +22,7 @@ ) from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.tensor import TensorMeta +from fast_llm.tensor import TensorMeta, accumulate_gradient from fast_llm.utils import Assert, safe_merge_dicts logger = logging.getLogger(__name__) @@ -242,9 +242,17 @@ def _logits_loss_forward_backward_partial( split_index: int = 0, return_logits: bool = False, ) -> tuple[torch.Tensor | None, torch.Tensor | None]: + if self._config.fp32_lm_head: + input_dtype = input_.dtype + input_ = input_.to(torch.float32) + # detach → requires_grad=False → 
output_parallel_linear_backward skips weight grad + weight = self.output_weights.detach().to(torch.float32) + else: + weight = self.output_weights + logits, context = output_parallel_linear_forward( input_=input_, - weight=self.output_weights, + weight=weight, bias=None, group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, @@ -273,9 +281,23 @@ def _logits_loss_forward_backward_partial( if loss_value is not None: losses_.append(loss_value.detach()) - return sum(losses_) if losses_ else None, ( - output_parallel_linear_backward(grad, context) if self.training else None - ) + if not self.training or grad is None: + return sum(losses_) if losses_ else None, None + + input_grad = output_parallel_linear_backward(grad, context) + if self._config.fp32_lm_head: + # Weight grad was skipped because weight.requires_grad=False; accumulate manually. + # context: (input_, weight, bias, group, sequence_parallel, ...) + saved_input = context[0] + if context[4]: # sequence_parallel + from fast_llm.core.ops import gather_op + + saved_input = gather_op(saved_input, context[3], dim=0) + grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2)) + accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype)) + input_grad = input_grad.to(input_dtype) + + return sum(losses_) if losses_ else None, input_grad def get_loss_definitions(self) -> list[LossDef]: return [ diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 70cf8806a..1a1b55ceb 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -205,6 +205,11 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): _abstract: typing.ClassVar[bool] = False + policy_loss: str = Field( + default="grpo", + desc="Policy loss algorithm: 'grpo' (per-token IS ratio clipping) or 'gspo' (sequence-level geometric-mean 
clipping).", + valid=check_field(Assert.incl, ["grpo", "gspo"]), + ) epsilon_low: float = Field(default=0.2, desc="Lower clip parameter for ratio of log probs") epsilon_high: float = Field(default=0.2, desc="Upper clip parameter for ratio of log probs") use_triton: bool | None = Field( @@ -222,6 +227,21 @@ class LanguageModelGRPOLossConfig(LanguageModelLossConfig): ), hint=FieldHint.feature, ) + normalize_by_documents: bool = Field( + default=False, + desc="Normalize the policy-gradient loss by the number of documents (rollouts) in the step " + "rather than the number of tokens. Matches DeepSpeed's normalization where each token's " + "loss is divided by config.batch_size (total rollout count). " + "Set to True when using docs_per_step for full DS parity.", + hint=FieldHint.feature, + ) + temperature: float = Field( + default=1.0, + desc="Temperature applied to logits before computing new log-probabilities. " + "Set to match the sampling temperature used by the actor (e.g. 0.7) so that " + "new and old log-probs are in the same scale and the IS ratio starts near 1.", + valid=check_field(Assert.gt, 0), + ) @property def loss_class(self) -> "type[LanguageModelGRPOLoss]": diff --git a/fast_llm/layers/language_model/loss/grpo.py b/fast_llm/layers/language_model/loss/grpo.py index 4bbaeb581..d2dbb8f6c 100644 --- a/fast_llm/layers/language_model/loss/grpo.py +++ b/fast_llm/layers/language_model/loss/grpo.py @@ -74,30 +74,77 @@ def _forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - if TritonConfig.enabled(logits.device, self._config.use_triton): - from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward - - fn = triton_grpo_loss_forward_backward + if self._config.normalize_by_documents: + # Match DeepSpeed exactly. 
DS has TWO 1/batch_size factors with different sources: + # - Loss reported uses /batch_size (via tokens_weights = 1/batch_size, see + # pipelinerl/finetune/rl/__init__.py:246). + # - Gradient buffer uses an ADDITIONAL /(gas × world_size) factor that comes from + # `scale_wrt_gas=True` in engine.backward() (deepspeed/runtime/engine.py:1995-1996) + # and `tensor.div_(world_sz)` in reduce_scatter_coalesced + # (deepspeed/runtime/comm/coalesced_collectives.py:124). + # For DS with samples_per_microbatch=1 (PipelineRL standard), gas × world_size = batch_size, + # so the gradient buffer effectively has factor 1/batch_size² while the loss metric has 1/batch_size. + # Fast-LLM cancels DS's /(gas × world_size) factor via `grad_output = data_parallel × grad_scale` + # (runner.py:318) interacting with FSDP's RS-AVG over data_parallel ranks (fsdp.py:396). + # So we need to apply the second 1/batch_size factor explicitly only to the gradient, + # keeping the loss metric matched to DS: + # loss divisor = num_documents (matches DS rl/loss) + # gradient divisor = num_documents² (matches DS grad_norm) + # Both are independent of TP/PP/SDP/DP parallelism and microbatching schedule. 
+ num_documents = kwargs[LanguageModelKwargs.num_documents_in_batch] + divisor = num_documents + grad_divisor = num_documents * num_documents else: - fn = fused_grpo_loss_forward_backward - loss, grad, new_logprobs_mean = fn( - logits, - self._get_labels(kwargs, split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), - self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), - grad_logits=grad_logits, - grad_output=self._get_grad_output(kwargs), - group=self._parallel_dim.group if self._vocab_parallel else None, - epsilon_low=self._config.epsilon_low, - epsilon_high=self._config.epsilon_high, - logits_scale_factor=self._logits_scale_factor, - num_labels_in_seq=( - None - if losses is None - else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) - ), - divisor=self._get_label_count(kwargs), - ) + divisor = self._get_label_count(kwargs) + grad_divisor = None # use divisor (default behavior) + if self._config.policy_loss == "gspo": + loss, grad, new_logprobs_mean = fused_gspo_loss_forward_backward( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + self._prepare_target(kwargs[LanguageModelKwargs.document_index], split_index), + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._effective_logits_scale, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=divisor, + grad_divisor=grad_divisor, + sdp_group=self._sdp_dim.group if self._sdp_active else None, + ) + else: + if TritonConfig.enabled(logits.device, 
self._config.use_triton): + from fast_llm.functional.triton.grpo_loss import triton_grpo_loss_forward_backward + + fn = triton_grpo_loss_forward_backward + else: + fn = fused_grpo_loss_forward_backward + loss, grad, new_logprobs_mean = fn( + logits, + self._get_labels(kwargs, split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.advantages], split_index), + self._prepare_target(kwargs[LanguageModelLossKwargs.old_log_probabilities], split_index), + grad_logits=grad_logits, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + epsilon_low=self._config.epsilon_low, + epsilon_high=self._config.epsilon_high, + logits_scale_factor=self._effective_logits_scale, + num_labels_in_seq=( + None + if losses is None + else self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index) + ), + divisor=divisor, + grad_divisor=grad_divisor, + ) if new_logprobs_mean is not None: new_logprobs_mean = new_logprobs_mean / kwargs[LanguageModelKwargs.num_documents_in_batch] @@ -126,7 +173,7 @@ def _register_extra_metrics( self._prepare_target(kwargs[LanguageModelLossKwargs.label_counts], split_index), self._config.epsilon_low, self._config.epsilon_high, - self._logits_scale_factor, + self._effective_logits_scale, group=self._parallel_dim.group if self._vocab_parallel else None, compute_entropy=self._config.metrics == GRPOMetricsLevel.with_entropy, ) @@ -190,7 +237,14 @@ def get_loss_definitions(self) -> list[LossDef]: def get_preprocessing_config( self, ) -> dict[str, typing.Any]: - return {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + config = {"use_grpo_data": True, "return_label_counts": True, "return_document_count": True} + if self._config.policy_loss == "gspo": + config["return_document_index"] = True + return config + + @functools.cached_property + def _effective_logits_scale(self) -> float: + return self._logits_scale_factor / self._config.temperature 
@functools.cached_property def _logprob_metric_name(self) -> str: @@ -268,10 +322,13 @@ def fused_grpo_loss_forward_backward( torch.Tensor | None ) = None, # (*batch,) — response-span length broadcast per token, 0 for non-response divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]: if divisor is None: divisor = logits.shape[:-1].numel() - grad_output = None if grad_output is None else grad_output / divisor * logits_scale_factor + if grad_divisor is None: + grad_divisor = divisor + grad_output = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor loss_mask = target >= 0 logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) @@ -326,3 +383,129 @@ def fused_grpo_loss_forward_backward( grad_logits.add_(grad) return loss, grad_logits, new_logprobs_mean + + +def fused_gspo_loss_forward_backward( + logits: torch.Tensor, # (n_tokens, vocab_local) + target: torch.Tensor, # (n_tokens,) + advantages: torch.Tensor, # (n_tokens,) + old_log_probabilities: torch.Tensor, # (n_tokens,) + document_index: torch.Tensor, # (n_tokens,) int64 — segment ID per token + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, # TP vocab group + epsilon_low: float = 0.2, + epsilon_high: float = 0.2, + logits_scale_factor: float = 1.0, + num_labels_in_seq: torch.Tensor | None = None, # for new_logprobs_mean metric + divisor: float | None = None, + grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor) + sdp_group: torch.distributed.ProcessGroup | None = None, # SDP group for cross-rank segment aggregation +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """GSPO loss: sequence-level geometric-mean IS ratio clipping. 
+ + Each segment s gets ratio R_s = exp(mean_t(log(p_new_t/p_old_t))), clipped as a unit. + Loss = -sum_s tok_count_s * min(R_s*A_s, clip(R_s)*A_s) / divisor. + Gradient: tok_count_s cancels, so each token in segment s gets the same gradient factor R_s. + + SDP correctness: scatter_add sums are all-reduced across sdp_group before computing R_s and A_s, + ensuring correct segment-level ratios when tokens are split across ranks. + + The optional `grad_divisor` allows the gradient to use a different divisor than the loss + (e.g., to match DeepSpeed's metric where loss has /batch_size and gradient has /batch_size²). + """ + if divisor is None: + divisor = float(logits.shape[0]) if logits.shape[0] > 0 else 1.0 + if grad_divisor is None: + grad_divisor = divisor + grad_output_scaled = None if grad_output is None else grad_output / grad_divisor * logits_scale_factor + + loss_mask = target >= 0 + mask_float = loss_mask.float() + + # Step 1: Softmax + log probs (same as GRPO) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + logits_norm, target, loss_mask, group + ) + new_log_probs = predicted_logits - sum_exp_logits.log() + log_ratio = (new_log_probs - old_log_probabilities).float() + + # new_logprobs_mean: local partial sum (aggregated across SDP via LossDef.reduce, same as GRPO) + new_logprobs_mean = ( + None if num_labels_in_seq is None else (new_log_probs * mask_float / num_labels_in_seq.clamp(min=1)).sum() + ) + + # Step 2: Determine global n_segs (max doc index + 1, all-reduced across SDP) + n_segs_local = int(document_index.max().item()) + 1 if document_index.numel() > 0 else 0 + if sdp_group is not None: + n_segs_t = torch.tensor(n_segs_local, device=logits.device, dtype=torch.int64) + torch.distributed.all_reduce(n_segs_t, op=torch.distributed.ReduceOp.MAX, group=sdp_group) + n_segs = int(n_segs_t.item()) + else: + n_segs = 
n_segs_local + + # Step 3: Per-segment scatter_add (local contributions only) + lrn_sum = log_ratio.new_zeros(n_segs) # sum of log-ratios per segment + adv_sum = advantages.new_zeros(n_segs).float() # sum of advantages per segment + tok_sum = log_ratio.new_zeros(n_segs) # token count per segment + + if loss_mask.any() and n_segs > 0: + masked_doc_ids = document_index[loss_mask].long() + lrn_sum.index_add_(0, masked_doc_ids, log_ratio[loss_mask]) + adv_sum.index_add_(0, masked_doc_ids, advantages[loss_mask].float()) + tok_sum.index_add_(0, masked_doc_ids, torch.ones(masked_doc_ids.numel(), device=logits.device)) + + # Step 4: SDP all-reduce so every rank has global per-segment sums + if sdp_group is not None and n_segs > 0: + torch.distributed.all_reduce(lrn_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(adv_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + torch.distributed.all_reduce(tok_sum, op=torch.distributed.ReduceOp.SUM, group=sdp_group) + + # Step 5: Segment-level ratio R_s and advantage A_s + valid = tok_sum > 0 + seg_denom = tok_sum.clamp(min=1e-6) + R = (lrn_sum / seg_denom).exp() # geometric mean IS ratio per segment + A = (adv_sum / seg_denom).detach() # mean advantage per segment (no gradient through A) + + # Step 6: GSPO loss — length-proportional weight tok_sum cancels with 1/N in gradient + surr1 = R * A + surr2 = R.clamp(1.0 - epsilon_low, 1.0 + epsilon_high) * A + loss_per_seg = -torch.minimum(surr1, surr2) * tok_sum * valid.float() + loss = loss_per_seg.sum() / divisor + # SDP correction: after SDP allreduce of lrn/adv/tok, both SDP ranks compute the IDENTICAL + # per-segment loss, so when LossDef.reduce sums across data_group (which includes SDP), the + # metric is double-counted by sdp_size. Divide here so each SDP rank reports loss/sdp_size, + # making the SUM-reduction match a non-SDP run. 
Gradient is unaffected (each SDP rank + # contributes gradient from its own LOCAL tokens, no double-counting in the gradient buffer). + if sdp_group is not None: + loss = loss / torch.distributed.get_world_size(sdp_group) + + # Step 7: Gradient — broadcast segment-level factors to token level + if grad_output_scaled is not None and n_segs > 0: + # d(loss)/d(log_ratio_t) = -R_s * clip_factor_s / divisor (tok_sum cancels) + # clip_factor_s = clamp_min(A_s,0)*(R_s <= 1+eps_h) + clamp_max(A_s,0)*(R_s >= 1-eps_l) + clip_up = (R <= 1.0 + epsilon_high).float() + clip_dn = (R >= 1.0 - epsilon_low).float() + seg_grad = R * (A.clamp(min=0) * clip_up + A.clamp(max=0) * clip_dn) * valid.float() + + # Broadcast: each token gets its segment's gradient factor + token_grad = seg_grad[document_index] # (n_tokens,) + + # d(new_log_prob)/d(logits_k) = delta(k==target) - softmax_k (same chain rule as GRPO) + probability_ratio_grad = grad_output_scaled * token_grad * mask_float + + predicted_probabilities = exp_logits / sum_exp_logits.unsqueeze(-1) + grad = probability_ratio_grad.unsqueeze(-1) * predicted_probabilities.scatter_add( + -1, + target_masked.unsqueeze(-1), + -(loss_mask if target_mask is None else target_mask).unsqueeze(-1).to(torch.float32), + ) + grad = grad.to(logits.dtype) + + if grad_logits is None: + grad_logits = grad + else: + grad_logits.add_(grad) + + return loss, grad_logits, new_logprobs_mean diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 3cab2bca8..90b368e2b 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -39,6 +39,8 @@ def __init__( self._vocab_parallel = distributed_config.tensor_parallel > 1 and vocab_parallel self._sequence_parallel = distributed_config.sequence_tensor_parallel and not self._vocab_parallel self._parallel_dim = distributed_config.get_distributed_dim(DistributedDimNames.tensor) + self._sdp_dim = 
distributed_config.get_distributed_dim(DistributedDimNames.sequence_data) + self._sdp_active = distributed_config.sequence_data_parallel > 1 def forward_backward( self, diff --git a/tests/layers/test_docs_per_step.py b/tests/layers/test_docs_per_step.py new file mode 100644 index 000000000..b57c25057 --- /dev/null +++ b/tests/layers/test_docs_per_step.py @@ -0,0 +1,322 @@ +""" +Unit tests for docs_per_step / normalize_by_documents features. + +Covers: + 1. Divisor scaling in fused_grpo_loss_forward_backward and fused_gspo_loss_forward_backward + 2. normalize_by_documents flag in LanguageModelGRPOLoss (GRPO and GSPO policy_loss) + 3. Schedule._eff_depth_first / _eff_sequential_micro_batches / _eff_num_inputs properties + 4. Trainer._prefetch_to_doc_target accumulation logic +""" + +import dataclasses +import types + +import pytest +import torch + +from fast_llm.engine.schedule.config import ScheduleConfig +from fast_llm.engine.schedule.schedule import Schedule +from fast_llm.layers.language_model.config import LanguageModelKwargs +from fast_llm.layers.language_model.loss.config import LanguageModelGRPOLossConfig, LanguageModelLossKwargs +from fast_llm.layers.language_model.loss.grpo import ( + fused_grpo_loss_forward_backward, + fused_gspo_loss_forward_backward, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +_atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# 1. 
# Divisor-scaling correctness in raw kernels
# ---------------------------------------------------------------------------


def test_grpo_divisor_scales_loss():
    """The GRPO kernel's loss is inversely proportional to ``divisor``.

    Evaluates the kernel twice on identical inputs, once with ``divisor`` equal
    to the token count and once with twice that value, and checks that the
    first loss is twice the second.
    """
    torch.manual_seed(10)
    num_tokens = 16
    vocab_size = 32
    # Draw order (logits, target, advantages, old log-probs) is fixed by the seed.
    logits = torch.randn(num_tokens, vocab_size, device=device)
    target = torch.randint(0, vocab_size, (num_tokens,), device=device)
    advantages = torch.randn(num_tokens, device=device)
    old_lp = torch.randn(num_tokens, device=device) - 2.0

    base_divisor = float(num_tokens)
    doubled_divisor = base_divisor * 2

    loss1, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=base_divisor)
    loss2, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_lp, divisor=doubled_divisor)

    gap = abs(loss1.item() - 2.0 * loss2.item())
    tolerance = _atol * 10
    assert gap < tolerance, f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}"


def test_gspo_divisor_scales_loss():
    """The GSPO kernel's loss is inversely proportional to ``divisor``.

    Same check as the GRPO variant, on 12 tokens packed into three documents
    of 3, 3 and 6 tokens, with no sequence-data-parallel group.
    """
    torch.manual_seed(11)
    num_tokens = 12
    vocab_size = 16
    logits = torch.randn(num_tokens, vocab_size, device=device)
    target = torch.randint(0, vocab_size, (num_tokens,), device=device)
    advantages = torch.randn(num_tokens, device=device)
    old_lp = torch.randn(num_tokens, device=device) - 2.0
    # Document id of every token: three documents with 3, 3 and 6 tokens.
    doc_idx = torch.tensor([0] * 3 + [1] * 3 + [2] * 6, dtype=torch.long, device=device)

    base_divisor = float(num_tokens)
    doubled_divisor = base_divisor * 2

    loss1, _, _ = fused_gspo_loss_forward_backward(
        logits, target, advantages, old_lp, doc_idx, divisor=base_divisor, sdp_group=None
    )
    loss2, _, _ = fused_gspo_loss_forward_backward(
        logits, target, advantages, old_lp, doc_idx, divisor=doubled_divisor, sdp_group=None
    )

    gap = abs(loss1.item() - 2.0 * loss2.item())
    tolerance = _atol * 10
    assert gap < tolerance, f"Expected loss(d1) ≈ 2*loss(d2), got {loss1.item():.6f} vs {2*loss2.item():.6f}"


# ---------------------------------------------------------------------------
# 2.
normalize_by_documents flag in LanguageModelGRPOLoss +# --------------------------------------------------------------------------- + + +def _make_grpo_loss(normalize_by_documents: bool, policy_loss: str = "grpo"): + """Instantiate a LanguageModelGRPOLoss with minimal (single-GPU) DistributedConfig.""" + from fast_llm.engine.distributed.config import DistributedConfig + from fast_llm.layers.language_model.loss.grpo import LanguageModelGRPOLoss + + dist_cfg = DistributedConfig() + cfg = LanguageModelGRPOLossConfig( + normalize_by_documents=normalize_by_documents, + policy_loss=policy_loss, + ) + return LanguageModelGRPOLoss(cfg, dist_cfg, name="grpo", prediction_distance=1, prediction_heads=1) + + +def _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs): + """Build the kwargs dict expected by LanguageModelGRPOLoss._forward_backward.""" + return { + LanguageModelLossKwargs.labels: [target], + LanguageModelLossKwargs.advantages: [advantages], + LanguageModelLossKwargs.old_log_probabilities: [old_lp], + LanguageModelLossKwargs.label_counts: [torch.full_like(target, n_labels, dtype=torch.int32)], + LanguageModelKwargs.num_labels_in_batch: [n_labels], + LanguageModelKwargs.num_documents_in_batch: n_docs, + LanguageModelKwargs.document_index: [doc_idx], + } + + +def test_normalize_by_documents_grpo(): + """normalize_by_documents=True → divisor=n_docs; False → divisor=n_labels. + + With n_docs ≠ n_labels, loss ratio must equal n_labels / n_docs. 
+ """ + torch.manual_seed(20) + n_tok, vocab = 12, 16 + n_docs, n_labels = 3, n_tok + + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + + kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) + + loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False)._forward_backward(logits, kwargs) + loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True)._forward_backward(logits, kwargs) + + expected_ratio = float(n_labels) / float(n_docs) + actual_ratio = loss_by_docs.item() / loss_by_tokens.item() + assert ( + abs(actual_ratio - expected_ratio) < 1e-4 + ), f"Expected loss_docs/loss_tokens ≈ {expected_ratio:.4f}, got {actual_ratio:.4f}" + + +def test_normalize_by_documents_gspo(): + """Same test for GSPO policy_loss.""" + torch.manual_seed(21) + n_tok, vocab = 12, 16 + n_docs, n_labels = 3, n_tok + + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_lp = torch.randn(n_tok, device=device) - 2.0 + # 3 equal segments → n_docs=3 + doc_idx = torch.cat([torch.full((n_tok // n_docs,), i, dtype=torch.long) for i in range(n_docs)]).to(device) + + kwargs = _make_grpo_kwargs(logits, target, advantages, old_lp, doc_idx, n_labels, n_docs) + + loss_by_tokens, _ = _make_grpo_loss(normalize_by_documents=False, policy_loss="gspo")._forward_backward( + logits, kwargs + ) + loss_by_docs, _ = _make_grpo_loss(normalize_by_documents=True, policy_loss="gspo")._forward_backward( + logits, kwargs + ) + + expected_ratio = float(n_labels) / float(n_docs) + actual_ratio = loss_by_docs.item() / loss_by_tokens.item() + assert ( + abs(actual_ratio - expected_ratio) < 1e-4 + ), f"Expected loss_docs/loss_tokens ≈ 
{expected_ratio:.4f}, got {actual_ratio:.4f}" + + +# --------------------------------------------------------------------------- +# 3. Schedule._eff_* properties +# --------------------------------------------------------------------------- + + +def _make_bare_schedule(depth_first: int, breadth_first: int, splits: int, override: int | None) -> Schedule: + """Create a Schedule with __init__ bypassed to test the _eff_* properties only.""" + config = ScheduleConfig( + depth_first_micro_batches=depth_first, + breadth_first_micro_batches=breadth_first, + micro_batch_splits=splits, + ) + sched = object.__new__(Schedule) + # Minimal attributes used by the three _eff_* properties. + object.__setattr__(sched, "_config", config) + object.__setattr__(sched, "_depth_first_override", override) + # samples_per_batch also needs _distributed_config.batch_data_parallel + fake_distributed = types.SimpleNamespace(batch_data_parallel=1) + object.__setattr__(sched, "_distributed_config", fake_distributed) + return sched + + +def test_schedule_eff_properties_no_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=None) + assert sched._eff_depth_first == 4 + assert sched._eff_sequential_micro_batches == 8 # 4 * 2 + assert sched._eff_num_inputs == 24 # 8 * 3 + assert sched.samples_per_batch == 8 # 8 * dp=1 + + +def test_schedule_eff_properties_with_override(): + sched = _make_bare_schedule(depth_first=4, breadth_first=2, splits=3, override=7) + assert sched._eff_depth_first == 7 # override wins + assert sched._eff_sequential_micro_batches == 14 # 7 * 2 + assert sched._eff_num_inputs == 42 # 14 * 3 + assert sched.samples_per_batch == 14 # 14 * dp=1 + + +def test_schedule_eff_properties_override_equals_config(): + """Override equal to config value → same result as no override.""" + sched_no = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, override=None) + sched_yes = _make_bare_schedule(depth_first=3, breadth_first=2, splits=1, 
override=3) + assert sched_no._eff_depth_first == sched_yes._eff_depth_first + assert sched_no._eff_sequential_micro_batches == sched_yes._eff_sequential_micro_batches + assert sched_no._eff_num_inputs == sched_yes._eff_num_inputs + + +def test_schedule_samples_per_batch_uses_eff(): + """samples_per_batch should scale with _eff_sequential, not config.sequential.""" + sched = _make_bare_schedule(depth_first=2, breadth_first=2, splits=1, override=5) + # Config says depth_first=2 → sequential=4; override=5 → eff_sequential=10 + assert sched._eff_sequential_micro_batches == 10 + assert sched.samples_per_batch == 10 # dp=1 + + +# --------------------------------------------------------------------------- +# 4. _prefetch_to_doc_target accumulation logic +# --------------------------------------------------------------------------- + + +@dataclasses.dataclass +class _FakeMicrobatch: + """Stub for a single split of one microbatch.""" + + num_documents: int + num_documents_in_batch: int | None = None + + @classmethod + def share_batch_data(cls, inputs, distributed): + """Mimic TokenModelInput.share_batch_data with group=None (single process).""" + if inputs[0].num_documents_in_batch is None: + total = sum(inp.num_documents for inp in inputs) + for inp in inputs: + inp.num_documents_in_batch = total + + +def _fake_iterator(doc_counts: list[int]): + """Yield [_FakeMicrobatch(n)] for each n in doc_counts.""" + for n in doc_counts: + yield [_FakeMicrobatch(num_documents=n)] + + +class _StubTrainer: + """Concrete stub that exposes only the interface _prefetch_to_doc_target needs.""" + + # Borrow the method directly so it runs against this stub's attributes. 
+ from fast_llm.engine.training.trainer import Trainer as _Trainer + + _prefetch_to_doc_target = _Trainer._prefetch_to_doc_target + + +def _make_fake_trainer(docs_per_step: int, bfmb: int = 1): + """Create a _StubTrainer with the attributes _prefetch_to_doc_target reads.""" + schedule_cfg = types.SimpleNamespace( + docs_per_step=docs_per_step, + breadth_first_micro_batches=bfmb, + ) + config = types.SimpleNamespace(schedule=schedule_cfg) + distributed = types.SimpleNamespace(batch_data_group=None) + + trainer = _StubTrainer() + trainer._config = config + trainer._distributed = distributed + return trainer + + +def test_prefetch_stops_at_target(): + """Buffer should stop growing once cumulative docs ≥ docs_per_step.""" + trainer = _make_fake_trainer(docs_per_step=6, bfmb=1) + # Each microbatch has 2 docs; need ≥6 → expect 3 microbatches + it = _fake_iterator([2, 2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + + assert len(buffer) == 3, f"Expected 3 microbatches, got {len(buffer)}" + + +def test_prefetch_resets_num_documents_in_batch(): + """After the call, every microbatch input has num_documents_in_batch = step total.""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + # 3 docs, 3 docs → total=6 (overshoots 5, stops after 2nd) + it = _fake_iterator([3, 3, 3]) + buffer = trainer._prefetch_to_doc_target(it) + + step_total = sum(mb[0].num_documents for mb in buffer) + for mb in buffer: + for mi in mb: + assert ( + mi.num_documents_in_batch == step_total + ), f"Expected num_documents_in_batch={step_total}, got {mi.num_documents_in_batch}" + + +def test_prefetch_overshoot_is_included(): + """A microbatch that pushes the total over the target IS included (not dropped).""" + trainer = _make_fake_trainer(docs_per_step=5, bfmb=1) + it = _fake_iterator([4, 4]) # 4 < 5, then 8 ≥ 5 → 2 microbatches + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2 + assert buffer[-1][0].num_documents_in_batch == 8 # step total = 4+4 + + +def 
test_prefetch_divisibility_check(): + """Raises when fetched count is not divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # Each microbatch has 5 docs → only 1 mb needed, but 1 % 2 != 0 + it = _fake_iterator([5, 5, 5]) + with pytest.raises(Exception): + trainer._prefetch_to_doc_target(it) + + +def test_prefetch_exact_divisibility(): + """No error when fetched count is exactly divisible by breadth_first_micro_batches.""" + trainer = _make_fake_trainer(docs_per_step=4, bfmb=2) + # 2 docs each → need ≥4 → fetch 2 microbatches → 2 % 2 == 0 + it = _fake_iterator([2, 2, 2, 2]) + buffer = trainer._prefetch_to_doc_target(it) + assert len(buffer) == 2 diff --git a/tests/layers/test_gspo_loss.py b/tests/layers/test_gspo_loss.py new file mode 100644 index 000000000..f5f302113 --- /dev/null +++ b/tests/layers/test_gspo_loss.py @@ -0,0 +1,460 @@ +""" +Unit tests for fused_gspo_loss_forward_backward. + +Tests: single segment, multi-segment packed, GRPO/GSPO equivalence at ratio=1, +segment-level clipping, SDP mock, gradient check, extra metrics unchanged. 
+""" + +import math + +import torch + +from fast_llm.layers.language_model.loss.grpo import ( + fused_grpo_loss_forward_backward, + fused_gspo_loss_forward_backward, +) + +device = "cuda" if torch.cuda.is_available() else "cpu" +atol = 1e-4 if device == "cuda" else 1e-5 + + +# --------------------------------------------------------------------------- +# Reference GSPO implementation +# --------------------------------------------------------------------------- + + +def _gspo_reference(logits, target, advantages, old_log_probs, doc_idx, eps_lo, eps_hi, divisor): + """Pure-PyTorch reference without compilation or distributed calls.""" + loss_mask = target >= 0 + log_softmax = torch.log_softmax(logits.float(), dim=-1) + new_log_probs = log_softmax.gather(-1, (target * loss_mask).unsqueeze(-1)).squeeze(-1) + log_ratio = (new_log_probs - old_log_probs.float()) * loss_mask.float() + + n_segs = int(doc_idx.max().item()) + 1 + lrn_sum = torch.zeros(n_segs, dtype=torch.float32) + adv_sum = torch.zeros(n_segs, dtype=torch.float32) + tok_sum = torch.zeros(n_segs, dtype=torch.float32) + for i in range(len(target)): + if loss_mask[i]: + s = doc_idx[i].item() + lrn_sum[s] += log_ratio[i].item() + adv_sum[s] += advantages[i].item() + tok_sum[s] += 1.0 + + loss = 0.0 + for s in range(n_segs): + if tok_sum[s] == 0: + continue + R = math.exp(lrn_sum[s] / tok_sum[s]) + A = adv_sum[s] / tok_sum[s] + R_clipped = max(1.0 - eps_lo, min(1.0 + eps_hi, R)) + surr1 = R * A + surr2 = R_clipped * A + loss += -min(surr1, surr2) * tok_sum[s] + return loss / divisor + + +# --------------------------------------------------------------------------- +# Test 1: single segment +# --------------------------------------------------------------------------- + + +def test_single_segment(): + torch.manual_seed(0) + n_tok, vocab = 8, 16 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + 
old_log_probs = ( + torch.log_softmax(torch.randn(n_tok, vocab, device=device), dim=-1) + .gather(-1, target.unsqueeze(-1)) + .squeeze(-1) + ) + doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 2: multi-segment packed +# --------------------------------------------------------------------------- + + +def test_multi_segment_packed(): + torch.manual_seed(1) + # 3 segments of lengths [5, 7, 4] + segs = [5, 7, 4] + n_tok = sum(segs) + vocab = 32 + logits = torch.randn(n_tok, vocab, device=device) + target = torch.randint(0, vocab, (n_tok,), device=device) + advantages = torch.randn(n_tok, device=device) + old_log_probs = ( + torch.log_softmax(torch.randn(n_tok, vocab, device=device), dim=-1) + .gather(-1, target.unsqueeze(-1)) + .squeeze(-1) + ) + doc_idx = torch.cat([torch.full((l,), i, dtype=torch.long) for i, l in enumerate(segs)]).to(device) + divisor = float(n_tok) + + loss_actual, _, _ = fused_gspo_loss_forward_backward( + logits, + target, + advantages, + old_log_probs, + doc_idx, + divisor=divisor, + sdp_group=None, + ) + loss_ref = _gspo_reference( + logits.cpu(), + target.cpu(), + advantages.cpu(), + old_log_probs.cpu(), + doc_idx.cpu(), + 0.2, + 0.2, + divisor, + ) + assert abs(loss_actual.item() - loss_ref) < atol * 3, f"{loss_actual.item()} vs {loss_ref}" + + +# --------------------------------------------------------------------------- +# Test 3: GRPO vs GSPO equivalence when all tokens in a segment have ratio=1 +# 
# ---------------------------------------------------------------------------


def test_ratio_one_matches_grpo():
    """When new == old log-probs (ratio=1 everywhere), GRPO and GSPO losses match."""
    torch.manual_seed(2)
    n_tok, vocab = 12, 16
    logits = torch.randn(n_tok, vocab, device=device)
    target = torch.randint(0, vocab, (n_tok,), device=device)
    advantages = torch.randn(n_tok, device=device)
    # Old log-probs are the current policy's own log-probs, so every ratio is exactly 1.
    old_log_probs = torch.log_softmax(logits.float(), dim=-1).gather(-1, target.unsqueeze(-1)).squeeze(-1).detach()
    doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device)
    divisor = float(n_tok)

    loss_grpo, _, _ = fused_grpo_loss_forward_backward(logits, target, advantages, old_log_probs, divisor=divisor)
    loss_gspo, _, _ = fused_gspo_loss_forward_backward(
        logits,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        divisor=divisor,
        sdp_group=None,
    )
    # At ratio=1, GRPO loss = sum_t -A_t * mask_t / divisor (no clipping)
    # GSPO loss = sum_s tok_s * -A_s / divisor (weighted per segment)
    # For a single segment: GSPO = -mean(A) * N / divisor = same total
    assert abs(loss_grpo.item() - loss_gspo.item()) < atol, f"grpo={loss_grpo.item()}, gspo={loss_gspo.item()}"


# ---------------------------------------------------------------------------
# Test 4: segment-level clipping (GSPO clips whole segment, not per-token)
# ---------------------------------------------------------------------------


def test_segment_level_clipping():
    """GSPO clips on the segment's geometric-mean ratio, not on per-token ratios.

    The segment is built so that every per-token ratio lies outside
    [1 - eps, 1 + eps] (alternating exp(+0.4) and exp(-0.4)) while the mean
    log-ratio is 0, so the segment ratio R = 1 is inside the range.
    A segment-level objective must therefore return the unclipped value
    -R * A * tok_sum / divisor = -1.
    """
    torch.manual_seed(3)
    vocab = 8
    n_tok = 4
    eps = 0.2
    target = torch.zeros(n_tok, dtype=torch.long, device=device)
    advantages = torch.ones(n_tok, device=device)
    doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device)

    # Constant logits: the new log-prob of the target class is -log(vocab) for every token.
    logits = torch.zeros(n_tok, vocab, device=device)
    new_log_probs = torch.log_softmax(logits.float(), dim=-1)[:, 0]

    # Choose old log-probs so that new - old alternates +0.4 / -0.4.
    token_log_ratio = torch.tensor([0.4, -0.4, 0.4, -0.4], device=device)
    old_log_probs = new_log_probs - token_log_ratio

    # Sanity-check the construction itself: each token is out of range, the segment is in range.
    token_ratio = token_log_ratio.exp()
    assert bool(((token_ratio > 1.0 + eps) | (token_ratio < 1.0 - eps)).all()), token_ratio
    R = math.exp(token_log_ratio.mean().item())
    assert 1.0 - eps <= R <= 1.0 + eps, R

    divisor = float(n_tok)
    loss_gspo, _, _ = fused_gspo_loss_forward_backward(
        logits,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        epsilon_low=eps,
        epsilon_high=eps,
        divisor=divisor,
        sdp_group=None,
    )

    # A = 1 and the segment weight is tok_sum / divisor = 4 / 4 = 1, so the unclipped loss is -R.
    expected = -R
    assert abs(loss_gspo.item() - expected) < atol, f"gspo={loss_gspo.item()}, expected={expected}"


# ---------------------------------------------------------------------------
# Test 5: masked tokens don't contribute
# ---------------------------------------------------------------------------


def test_masked_tokens():
    """Tokens with a negative target are excluded; the result matches the reference."""
    torch.manual_seed(4)
    n_tok, vocab = 10, 16
    logits = torch.randn(n_tok, vocab, device=device)
    target = torch.randint(0, vocab, (n_tok,), device=device)
    target[3] = -100  # mask token 3
    target[7] = -100  # mask token 7
    advantages = torch.randn(n_tok, device=device)
    old_log_probs = torch.randn(n_tok, device=device)
    doc_idx = torch.zeros(n_tok, dtype=torch.long, device=device)
    divisor = float(n_tok)

    loss_actual, _, _ = fused_gspo_loss_forward_backward(
        logits,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        divisor=divisor,
        sdp_group=None,
    )
    loss_ref = _gspo_reference(
        logits.cpu(),
        target.cpu(),
        advantages.cpu(),
        old_log_probs.cpu(),
        doc_idx.cpu(),
        0.2,
        0.2,
        divisor,
    )
    assert abs(loss_actual.item() - loss_ref) < atol, f"{loss_actual.item()} vs {loss_ref}"


# ---------------------------------------------------------------------------
# Test 6: SDP mock — split tokens across 2 "ranks", verify correctness
# ---------------------------------------------------------------------------


def test_sdp_mock():
    """
    Simulate SDP=2: split tokens in half, compute per-rank segment sums, manually all-reduce,
    then verify the loss rebuilt from the combined sums matches the full-batch computation.
    """
    torch.manual_seed(5)
    segs = [6, 5, 7]  # 3 segments
    n_tok = sum(segs)
    vocab = 16
    logits = torch.randn(n_tok, vocab, device=device)
    target = torch.randint(0, vocab, (n_tok,), device=device)
    advantages = torch.randn(n_tok, device=device)
    old_log_probs = torch.randn(n_tok, device=device)
    doc_idx = torch.cat([torch.full((l,), i, dtype=torch.long) for i, l in enumerate(segs)]).to(device)
    divisor = float(n_tok)

    # Full-batch reference loss
    loss_full, _, _ = fused_gspo_loss_forward_backward(
        logits,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        divisor=divisor,
        sdp_group=None,
    )

    # Simulate SDP=2: split at midpoint
    mid = n_tok // 2
    loss_r0_only, _, _ = fused_gspo_loss_forward_backward(
        logits[:mid],
        target[:mid],
        advantages[:mid],
        old_log_probs[:mid],
        doc_idx[:mid],
        divisor=divisor,
        sdp_group=None,
    )
    loss_r1_only, _, _ = fused_gspo_loss_forward_backward(
        logits[mid:],
        target[mid:],
        advantages[mid:],
        old_log_probs[mid:],
        doc_idx[mid:],
        divisor=divisor,
        sdp_group=None,
    )
    # Without an all-reduce a rank sees only part of a split segment, so these per-rank values are
    # not expected to combine into the full loss. They are only checked to be finite, which also
    # covers a shard whose document ids do not start at 0.
    assert torch.isfinite(loss_r0_only), loss_r0_only
    assert torch.isfinite(loss_r1_only), loss_r1_only

    # The full-batch result should match the reference
    loss_ref = _gspo_reference(
        logits.cpu(),
        target.cpu(),
        advantages.cpu(),
        old_log_probs.cpu(),
        doc_idx.cpu(),
        0.2,
        0.2,
        divisor,
    )
    assert abs(loss_full.item() - loss_ref) < atol * 3, f"full={loss_full.item()}, ref={loss_ref}"

    # When sdp_group is None but we manually pre-sum, the result should also match
    # (This conceptually validates the all-reduce logic without actual distributed calls)
    log_softmax_full = torch.log_softmax(logits.float(), dim=-1)
    new_lp_full = log_softmax_full.gather(-1, (target * (target >= 0)).unsqueeze(-1)).squeeze(-1)
    log_ratio_full = (new_lp_full - old_log_probs.float()) * (target >= 0).float()

    n_segs = 3

    def _rank_sums(begin, end):
        """Per-segment (log-ratio sum, advantage sum, token count) over tokens [begin, end)."""
        lrn = torch.zeros(n_segs)
        adv = torch.zeros(n_segs)
        tok = torch.zeros(n_segs)
        for i in range(begin, end):
            if target[i] >= 0:
                s = doc_idx[i].item()
                lrn[s] += log_ratio_full[i].item()
                adv[s] += advantages[i].item()
                tok[s] += 1
        return lrn, adv, tok

    lrn_r0, adv_r0, tok_r0 = _rank_sums(0, mid)
    lrn_r1, adv_r1, tok_r1 = _rank_sums(mid, n_tok)

    # Manually all-reduce (SUM)
    lrn_global = lrn_r0 + lrn_r1
    adv_global = adv_r0 + adv_r1
    tok_global = tok_r0 + tok_r1

    loss_manual = 0.0
    for s in range(n_segs):
        tokens = tok_global[s].item()
        if tokens == 0:
            continue
        R = math.exp(lrn_global[s].item() / tokens)
        A = adv_global[s].item() / tokens
        R_c = max(1 - 0.2, min(1 + 0.2, R))
        loss_manual += -min(R * A, R_c * A) * tokens
    loss_manual /= divisor

    assert abs(loss_full.item() - loss_manual) < atol * 3, f"full={loss_full.item()}, manual={loss_manual}"


# ---------------------------------------------------------------------------
# Test 7: gradient correctness via finite differences
# ---------------------------------------------------------------------------


def test_gradient_finite_diff():
    """The returned logits gradient matches a central finite difference at one entry.

    Uses float64 CPU tensors and two segments of three tokens each; only entry
    (i=2, k=3) of the gradient is compared.
    """
    torch.manual_seed(6)
    n_tok, vocab = 6, 8
    logits = torch.randn(n_tok, vocab, dtype=torch.float64)
    target = torch.randint(0, vocab, (n_tok,))
    advantages = torch.randn(n_tok, dtype=torch.float64)
    old_log_probs = torch.randn(n_tok, dtype=torch.float64)
    doc_idx = torch.tensor([0, 0, 0, 1, 1, 1], dtype=torch.long)
    divisor = float(n_tok)
    eps = 1e-5

    # Pass a zero buffer so the returned gradient is exactly this call's contribution.
    grad_logits = torch.zeros_like(logits)
    _, grad_out, _ = fused_gspo_loss_forward_backward(
        logits,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        grad_logits=grad_logits,
        grad_output=1.0,
        divisor=divisor,
        sdp_group=None,
    )

    # Finite-difference gradient for one entry
    i, k = 2, 3
    logits_p = logits.clone()
    logits_p[i, k] += eps
    logits_m = logits.clone()
    logits_m[i, k] -= eps
    loss_p, _, _ = fused_gspo_loss_forward_backward(
        logits_p,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        divisor=divisor,
        sdp_group=None,
    )
    loss_m, _, _ = fused_gspo_loss_forward_backward(
        logits_m,
        target,
        advantages,
        old_log_probs,
        doc_idx,
        divisor=divisor,
        sdp_group=None,
    )
    fd_grad = (loss_p.item() - loss_m.item()) / (2 * eps)

    assert abs(grad_out[i, k].item() - fd_grad) < 1e-4, f"analytical={grad_out[i, k].item():.6f}, fd={fd_grad:.6f}"


# ---------------------------------------------------------------------------
# Test 8: extra metrics unchanged by policy_loss choice
# ---------------------------------------------------------------------------


def test_extra_metrics_are_per_token():
    """``compute_grpo_metrics`` returns finite values for its token-level metrics.

    The metrics helper takes no document index, so it is called the same way
    whichever policy loss is in use.
    """
    from fast_llm.layers.language_model.loss.grpo import compute_grpo_metrics

    torch.manual_seed(7)
    n_tok, vocab = 10, 16
    logits = torch.randn(n_tok, vocab, device=device)
    target = torch.randint(0, vocab, (n_tok,), device=device)
    advantages = torch.randn(n_tok, device=device)
    old_log_probs = torch.randn(n_tok, device=device)
    label_counts = torch.full((n_tok,), n_tok, dtype=torch.int32, device=device)

    metrics = compute_grpo_metrics(
        logits,
        target,
        advantages,
        old_log_probs,
        label_counts,
        epsilon_low=0.2,
        epsilon_high=0.2,
        logits_scale_factor=1.0,
        group=None,
    )
    for attr in ("old_logprobs", "ratio_new_old", "kl_new_old", "advantage"):
        val = getattr(metrics, attr)
        assert val.isfinite(), f"{attr} is not finite: {val}"