Skip to content
Open
1 change: 1 addition & 0 deletions fast_llm/data/document/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ class LanguageModelBatchPreprocessingConfig(TokenPreprocessingConfig):
use_preference_spans: bool = Field(default=False)
use_grpo_data: bool = Field(default=False)
return_label_counts: bool = Field(default=False)
return_document_index: bool = Field(default=False)

def _validate(self) -> None:
super()._validate()
Expand Down
13 changes: 10 additions & 3 deletions fast_llm/data/document/language_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ class LanguageModelTargetInput(ModelInput):
advantages: torch.Tensor | None = None
old_log_probabilities: torch.Tensor | None = None
label_counts: torch.Tensor | None = None
document_index: torch.Tensor | None = None
num_labels: int | None = None
num_labels_in_batch: int | None = None

Expand Down Expand Up @@ -84,6 +85,7 @@ def to_kwargs(self) -> dict[str, typing.Any]:
LanguageModelKwargs.advantages: [target.advantages for target in self.targets],
LanguageModelKwargs.old_log_probabilities: [target.old_log_probabilities for target in self.targets],
LanguageModelKwargs.label_counts: [target.label_counts for target in self.targets],
LanguageModelKwargs.document_index: [target.document_index for target in self.targets],
LanguageModelKwargs.num_labels_in_batch: [target.num_labels_in_batch for target in self.targets],
}
if self.image_patches is not None:
Expand Down Expand Up @@ -177,7 +179,11 @@ def _set_target_inputs(
document_begin += length

mask = labels >= 0
label_counts = self._get_label_counts(mask) if config.return_label_counts else None
label_counts, document_index = (
self._get_label_counts(mask)
if config.return_label_counts or config.return_document_index
else (None, None)
)

for input_index, model_input in enumerate(model_inputs):
label_end = model_input.sequence_k_dim.size + prediction_distance
Expand All @@ -188,6 +194,7 @@ def _set_target_inputs(
tokens=labels[label_begin:label_end].clone(),
mask=mask[label_begin:label_end] if config.return_prediction_mask else None,
label_counts=label_counts[label_begin:label_end] if config.return_label_counts else None,
document_index=document_index[label_begin:label_end] if config.return_document_index else None,
# Set value for the first input only so `share_batch_data` generates the correct sum.
# TODO: ====== Make optional?
num_labels=(
Expand All @@ -202,7 +209,7 @@ def _set_target_inputs(

model_input.targets.append(target_input)

def _get_label_counts(self, mask: torch.Tensor):
def _get_label_counts(self, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
# Count the number of non-masked labels in each document through cumulative sums.
mask_cumsum = torch.cat([mask.new_zeros(1), mask.cumsum(0)])
length_cumsum = torch.tensor([0] + self.lengths, device=self.device).cumsum(0)
Expand All @@ -214,4 +221,4 @@ def _get_label_counts(self, mask: torch.Tensor):
document_index = torch.searchsorted(
length_cumsum[1:], torch.arange(len(mask), device=self.device), side="right"
)
return labels_per_document[document_index]
return labels_per_document[document_index], document_index
10 changes: 10 additions & 0 deletions fast_llm/engine/schedule/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@ class ScheduleConfig(Config):
hint=FieldHint.core,
valid=check_field(Assert.gt, 0),
)
docs_per_step: int = Field(
default=0,
desc="Target number of documents (rollouts) per optimizer step, globally across all data-parallel ranks. "
"When >0, each training step dynamically accumulates microbatches until the globally all-reduced "
"document count reaches this value, then triggers the optimizer step. "
"depth_first_micro_batches is ignored when this is set. "
"0 = use depth_first_micro_batches as-is (fixed microbatch count per step).",
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
breadth_first_micro_batches: int = Field(
default=1,
desc="Number of micro-batches processed breadth-first, i.e., interleaved across model stages.",
Expand Down
5 changes: 3 additions & 2 deletions fast_llm/engine/schedule/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,8 @@ def _preprocess_data(
if context.schedule.phase.is_training
else None
)
model_inputs = [next(data_iterator) for _ in range(self._config.sequential_micro_batches)]
n_micro_batches = context.schedule._eff_sequential_micro_batches
model_inputs = [next(data_iterator) for _ in range(n_micro_batches)]
model_inputs[0][0].share_batch_data(
[model_input for model_inputs_ in model_inputs for model_input in model_inputs_], self._distributed
)
Expand All @@ -336,7 +337,7 @@ def _preprocess_data(
extra_kwargs={
"grad_output": grad_output,
"micro_batch": micro_batch,
"num_micro_batches": self._config.sequential_micro_batches,
"num_micro_batches": n_micro_batches,
"micro_batch_splits": self._config.micro_batch_splits,
},
)
Expand Down
38 changes: 25 additions & 13 deletions fast_llm/engine/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,17 @@ def __init__(
batch_meta: list[ModelInput],
distributed_config: DistributedConfig,
phase: PhaseType,
_depth_first_override: int | None = None,
):
super().__init__(config)
self._depth_first_override = _depth_first_override
self._multi_stage = multi_stage
self._distributed_config = distributed_config
self._num_stages = len(self._multi_stage.stages)
self._phase = phase
self._is_training = self._phase.is_training

if self._config.num_inputs < self._distributed_config.pipeline_parallel:
if self._eff_num_inputs < self._distributed_config.pipeline_parallel:
warnings.warn("Not enough input to achieve true pipeline parallelism.")

# Setup the activation metas.
Expand Down Expand Up @@ -155,9 +157,25 @@ def __init__(
def phase(self) -> PhaseType:
return self._phase

@property
def _eff_depth_first(self) -> int:
return (
self._depth_first_override
if self._depth_first_override is not None
else self._config.depth_first_micro_batches
)

@property
def _eff_sequential_micro_batches(self) -> int:
    # Total micro-batches processed sequentially per step: effective depth-first
    # count times the configured breadth-first count.
    return self._eff_depth_first * self._config.breadth_first_micro_batches

@property
def _eff_num_inputs(self) -> int:
    # Number of data inputs consumed per step: each sequential micro-batch is
    # further divided into `micro_batch_splits` inputs.
    return self._eff_sequential_micro_batches * self._config.micro_batch_splits

@property
def samples_per_batch(self) -> int:
return self._config.sequential_micro_batches * self._distributed_config.batch_data_parallel
return self._eff_sequential_micro_batches * self._distributed_config.batch_data_parallel

def iterate(self, pipeline_rank: int | None = None) -> typing.Iterator[Step]:
return iter(self._steps if pipeline_rank is None else self._device_steps[pipeline_rank])
Expand Down Expand Up @@ -189,7 +207,7 @@ def _create_index(self) -> None:
Assert.in_range(
step.index,
0,
self._config.num_inputs,
self._eff_num_inputs,
)
Assert.incl(step.type_, (StepType.forward, StepType.backward))
step.global_index = i
Expand All @@ -205,7 +223,7 @@ def _create_index(self) -> None:
Assert.custom(all, self._device_steps)
# Consistency checks
step_map = self._step_map.copy()
for data_index in range(self._config.num_inputs):
for data_index in range(self._eff_num_inputs):
for type_ in (StepType.forward, StepType.backward):
for stage in range(0 if type_ == StepType.forward else self._first_grad_stage, self._num_stages):
assert (
Expand Down Expand Up @@ -470,14 +488,11 @@ def _create_steps(self) -> tuple[list[Step], int]:
first_grad_stage += 1
else:
first_grad_stage = self._num_stages
for depth_first_micro_batch in range(self._config.depth_first_micro_batches):
for depth_first_micro_batch in range(self._eff_depth_first):
for stage in range(self._num_stages):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in range(self._config.micro_batch_splits):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand All @@ -492,10 +507,7 @@ def _create_steps(self) -> tuple[list[Step], int]:
for stage in reversed(range(first_grad_stage, self._num_stages)):
for breadth_first_micro_batch in range(self._config.breadth_first_micro_batches):
for micro_batch_split in reversed(range(self._config.micro_batch_splits)):
micro_batch = (
breadth_first_micro_batch * self._config.depth_first_micro_batches
+ depth_first_micro_batch
)
micro_batch = breadth_first_micro_batch * self._eff_depth_first + depth_first_micro_batch
steps.append(
Step(
stage=stage,
Expand Down
62 changes: 55 additions & 7 deletions fast_llm/engine/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,13 @@ def setup(self, distributed: Distributed, run: Run) -> None:
preprocessing_config = self._multi_stage.get_preprocessing_config(
PhaseType.training, self._config.schedule.micro_batch_splits
)
self._preprocessing_config = preprocessing_config
self._single_mb_meta = preprocessing_config.get_input_meta(self._data.config.micro_batch_size)
self._schedule_cache: dict[int, Schedule] = {}
self._schedule = Schedule(
config=self._config.schedule,
multi_stage=self._multi_stage,
batch_meta=preprocessing_config.get_input_meta(self._data.config.micro_batch_size),
batch_meta=self._single_mb_meta,
distributed_config=self._config.model.distributed,
phase=PhaseType.training,
)
Expand All @@ -140,6 +143,41 @@ def setup(self, distributed: Distributed, run: Run) -> None:

self._is_setup = True

def _get_or_build_schedule(self, n_microbatches: int) -> Schedule:
if n_microbatches not in self._schedule_cache:
bfmb = self._config.schedule.breadth_first_micro_batches
depth_first = n_microbatches // bfmb
self._schedule_cache[n_microbatches] = Schedule(
config=self._config.schedule,
multi_stage=self._multi_stage,
batch_meta=self._single_mb_meta,
distributed_config=self._config.model.distributed,
phase=PhaseType.training,
_depth_first_override=depth_first,
)
return self._schedule_cache[n_microbatches]

def _prefetch_to_doc_target(self, data_iterator) -> list:
target = self._config.schedule.docs_per_step
bfmb = self._config.schedule.breadth_first_micro_batches
buffer = []
total_docs = 0
while total_docs < target:
mb = next(data_iterator)
mb[0].share_batch_data(mb, self._distributed)
total_docs += mb[0].num_documents_in_batch
buffer.append(mb)
Assert.eq(
len(buffer) % bfmb,
0,
msg=f"Fetched {len(buffer)} microbatches not divisible by breadth_first_micro_batches={bfmb}",
)
# Reset num_documents_in_batch to the step total on all microbatches
for mb in buffer:
for mi in mb:
mi.num_documents_in_batch = total_docs
return buffer

@abc.abstractmethod
def _get_data(self) -> Data:
pass
Expand Down Expand Up @@ -220,12 +258,22 @@ def _train(self) -> tuple[bool, dict[PhaseType, dict[str, typing.Any]]]:

# TODO: Data loader hates getting all micro-batches at once.
# (Also preprocessing adds overhead)
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
if self._config.schedule.docs_per_step > 0:
buffer = self._prefetch_to_doc_target(train_iterator)
step_schedule = self._get_or_build_schedule(len(buffer))
reduced_losses, update_successful, train_metrics = self._runner.run_step(
iter(buffer),
step_schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)
else:
reduced_losses, update_successful, train_metrics = self._runner.run_step(
train_iterator,
self._schedule,
iteration=self._completed_steps,
return_metrics=is_logging,
)

# Advanced, skipped, and Nan iterations.
if update_successful:
Expand Down
5 changes: 4 additions & 1 deletion fast_llm/functional/triton/grpo_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def triton_grpo_loss_forward_backward(
logits_scale_factor: float = 1.0,
num_labels_in_seq: torch.Tensor | None = None,
divisor: float | None = None,
grad_divisor: float | None = None, # Optional separate divisor for the gradient (defaults to divisor)
block_size: int | None = None,
num_warps: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Expand All @@ -148,6 +149,8 @@ def triton_grpo_loss_forward_backward(
n_cols = logits.size(-1)
if divisor is None:
divisor = n_rows
if grad_divisor is None:
grad_divisor = divisor
if block_size is None:
block_size = min(triton.next_power_of_2(n_cols), 32768)
if num_warps is None:
Expand All @@ -171,7 +174,7 @@ def triton_grpo_loss_forward_backward(
grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits
backward_kwargs = {
"grad_logits_ptr": grad_logits,
"grad_losses": grad_output / divisor,
"grad_losses": grad_output / grad_divisor,
"grad_logits_stride_0": grad_logits.stride(-2),
"accumulate": accumulate,
}
Expand Down
8 changes: 8 additions & 0 deletions fast_llm/layers/language_model/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class LanguageModelKwargs(LanguageModelLossKwargs):
sample_map = "sample_map"
embedding_map = "embedding_map"
num_documents_in_batch = "num_documents_in_batch"
document_index = "document_index"
# TODO: These are generic
phase = "phase"
loss_mask = "loss_mask"
Expand Down Expand Up @@ -119,6 +120,13 @@ class LanguageModelHeadConfig(BlockConfig):
hint=FieldHint.feature,
valid=check_field(Assert.geq, 0),
)
fp32_lm_head: bool = Field(
default=False,
desc="Upcast input and weight to float32 before the lm_head linear. "
"Matches vLLM's bf16_last_layer_fp32 quantization so new_logprobs and old_logprobs "
"are computed at the same numerical precision, keeping the IS ratio near 1 at init.",
hint=FieldHint.feature,
)
prediction_heads: int = Field(
default=1,
desc="Prediction heads.",
Expand Down
32 changes: 27 additions & 5 deletions fast_llm/layers/language_model/head.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
)
from fast_llm.layers.language_model.loss.config import LanguageModelLabelEntropyLossConfig
from fast_llm.layers.language_model.loss.loss import LanguageModelLoss
from fast_llm.tensor import TensorMeta
from fast_llm.tensor import TensorMeta, accumulate_gradient
from fast_llm.utils import Assert, safe_merge_dicts

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -242,9 +242,17 @@ def _logits_loss_forward_backward_partial(
split_index: int = 0,
return_logits: bool = False,
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
if self._config.fp32_lm_head:
input_dtype = input_.dtype
input_ = input_.to(torch.float32)
# detach → requires_grad=False → output_parallel_linear_backward skips weight grad
weight = self.output_weights.detach().to(torch.float32)
else:
weight = self.output_weights

logits, context = output_parallel_linear_forward(
input_=input_,
weight=self.output_weights,
weight=weight,
bias=None,
group=self._parallel_dim.group if self._vocab_parallel else None,
sequence_parallel=self._sequence_parallel and self._vocab_parallel,
Expand Down Expand Up @@ -273,9 +281,23 @@ def _logits_loss_forward_backward_partial(
if loss_value is not None:
losses_.append(loss_value.detach())

return sum(losses_) if losses_ else None, (
output_parallel_linear_backward(grad, context) if self.training else None
)
if not self.training or grad is None:
return sum(losses_) if losses_ else None, None

input_grad = output_parallel_linear_backward(grad, context)
if self._config.fp32_lm_head:
# Weight grad was skipped because weight.requires_grad=False; accumulate manually.
# context: (input_, weight, bias, group, sequence_parallel, ...)
saved_input = context[0]
if context[4]: # sequence_parallel
from fast_llm.core.ops import gather_op

saved_input = gather_op(saved_input, context[3], dim=0)
grad_weight = grad.flatten(0, -2).t().mm(saved_input.flatten(0, -2))
accumulate_gradient(self.output_weights, grad_weight.to(self.output_weights.dtype))
input_grad = input_grad.to(input_dtype)

return sum(losses_) if losses_ else None, input_grad

def get_loss_definitions(self) -> list[LossDef]:
return [
Expand Down
Loading