From 9a3e86d343dffc716053ebcbff99d0a55a1151d4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 29 Jan 2026 16:05:09 -0500 Subject: [PATCH 01/22] Entropy loss tweaks --- fast_llm/core/distributed.py | 39 +- fast_llm/core/ops.py | 7 +- fast_llm/data/{ => data}/data_loader.py | 0 fast_llm/data/data/gpt/data.py | 2 +- fast_llm/data/preprocessing/tokenizer.py | 22 +- fast_llm/data/sample/patch.py | 2 +- fast_llm/data/sample/range.py | 2 +- fast_llm/data/sample/token.py | 2 +- fast_llm/engine/checkpoint/config.py | 3 +- fast_llm/engine/multi_stage/fsdp.py | 3 +- fast_llm/engine/multi_stage/multi_stage.py | 3 +- fast_llm/functional/entropy_loss.py | 240 ++++++------ fast_llm/functional/triton/cross_entropy.py | 13 +- fast_llm/layers/language_model/loss/dpo.py | 18 +- .../language_model/loss/entropy_loss.py | 70 +++- fast_llm/layers/language_model/loss/loss.py | 2 +- fast_llm/layers/language_model/loss/z_loss.py | 48 ++- fast_llm/models/gpt/config.py | 5 + fast_llm/models/gpt/trainer.py | 4 +- tests/functional/test_entropy_loss.py | 178 --------- tests/functional/test_functional.py | 56 --- tests/layers/test_lm_losses.py | 346 ++++++++++++++++++ tests/models/test_checkpoint.py | 4 - tests/models/test_model.py | 2 - tests/utils/dataset.py | 6 +- tests/utils/subtest.py | 21 +- 26 files changed, 658 insertions(+), 440 deletions(-) rename fast_llm/data/{ => data}/data_loader.py (100%) delete mode 100644 tests/functional/test_entropy_loss.py create mode 100644 tests/layers/test_lm_losses.py diff --git a/fast_llm/core/distributed.py b/fast_llm/core/distributed.py index da443c4f6..16f7d92c8 100644 --- a/fast_llm/core/distributed.py +++ b/fast_llm/core/distributed.py @@ -15,7 +15,6 @@ import torch import torch.monitor -from torch._C._distributed_c10d import Work from torch.distributed import ( # noqa ProcessGroup, ReduceOp, @@ -29,6 +28,15 @@ logger = logging.getLogger(__name__) +def _get_device(group: ProcessGroup) -> torch.device: + if 
torch.distributed.is_nccl_available() and isinstance(group, torch.distributed.ProcessGroupNCCL): + return torch.device(torch.cuda.current_device()) + elif isinstance(group, torch.distributed.ProcessGroupGloo): + return torch.device("cpu") + else: + raise NotImplementedError(type(group)) + + @contextlib.contextmanager def set_timeout(group: ProcessGroup | None, timeout: float | None = None): if group is not None and timeout is not None: @@ -42,7 +50,7 @@ def set_timeout(group: ProcessGroup | None, timeout: float | None = None): def broadcast( tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, timeout: float | None = None -) -> Work | None: +) -> torch.distributed.Work | None: """Same as torch.distributed.broadcast, but without the complication of going through the global rank.""" assert group is not None opts = torch.distributed.BroadcastOptions() @@ -72,12 +80,10 @@ def check_parallel_match(tensor: torch.Tensor, group: ProcessGroup | None, name: ) -def safe_barrier( - group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None, device: torch.device | None = None -) -> None: +def safe_barrier(group: ProcessGroup | None, value: int | str = 1, timeout: float | None = None) -> None: if group: hashed = hash(value) % 2**32 - out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout, device=device) + out = allreduce_scalar(hashed, dtype=torch.int64, group=group, timeout=timeout) if out != hashed * group.size(): raise RuntimeError(f"Desync detected for barrier {value} ({out}!={hashed*group.size()})") @@ -88,10 +94,9 @@ def allreduce_scalar( group: torch.distributed.ProcessGroup | None = None, op=ReduceOp.SUM, timeout: float | None = None, - device: torch.device | None = None, ) -> float | int: if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device() if device is None else device) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) with set_timeout(group, 
timeout): torch.distributed.all_reduce(value, op=op, group=group) return value.item() @@ -106,7 +111,7 @@ def all_gather_scalar( timeout: float | None = None, ): if group: - value = torch.full([1], value, dtype=dtype, device=torch.cuda.current_device()) + value = torch.full([1], value, dtype=dtype, device=_get_device(group)) output_tensor = value.new_empty((group.size(),)) with set_timeout(group, timeout): torch.distributed.all_gather_into_tensor(output_tensor, value, group=group) @@ -116,7 +121,7 @@ def all_gather_scalar( def broadcast_scalar( - value: float | int, + value: float | int | None, dtype: torch.dtype = torch.float64, group: torch.distributed.ProcessGroup | None = None, src: int = 0, @@ -124,7 +129,7 @@ def broadcast_scalar( ) -> float | int: if not group: return value - tensor = torch.empty([1], dtype=dtype, device=torch.device(torch.cuda.current_device())) + tensor = torch.empty([1], dtype=dtype, device=torch.device(_get_device(group))) if group.rank() == src: tensor.fill_(value) broadcast(tensor, src, group, timeout=timeout) @@ -141,19 +146,21 @@ def broadcast_object(input_object: typing.Any | None, group: ProcessGroup | None if group.rank() == src: tensor = _object_to_tensor(input_object) size = tensor.numel() - broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + broadcast_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast_tensor.copy_(tensor) broadcast_scalar(size, torch.int64, group, src) broadcast(broadcast_tensor, src, group) return input_object else: size = int(broadcast_scalar(None, torch.int64, group, src)) - output_tensor = torch.empty(size, dtype=torch.uint8, device=torch.cuda.current_device()) + output_tensor = torch.empty(size, dtype=torch.uint8, device=_get_device(group)) broadcast(output_tensor, src, group) return _tensor_to_object(output_tensor) -def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: +def send( 
+ tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, tag: int = 0 +) -> torch.distributed.Work | None: assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # send not supported for gloo on GPU. @@ -169,7 +176,9 @@ def send(tensor: torch.Tensor, dst: int, group: ProcessGroup, async_op=False, ta return None -def recv(tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0) -> Work | None: +def recv( + tensor: torch.Tensor, src: int, group: ProcessGroup, async_op=False, tag: int = 0 +) -> torch.distributed.Work | None: assert group is not None if isinstance(group, torch.distributed.ProcessGroupGloo) and tensor.device.type != "cpu": # recv not supported for gloo on GPU. diff --git a/fast_llm/core/ops.py b/fast_llm/core/ops.py index bb61aadd0..7d361a22e 100644 --- a/fast_llm/core/ops.py +++ b/fast_llm/core/ops.py @@ -8,7 +8,6 @@ import torch import torch._dynamo # noqa import torch.autograd -from torch._C._distributed_c10d import Work from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_gather_into_tensor, all_reduce, reduce_scatter_tensor from fast_llm.utils import Assert, div @@ -18,7 +17,7 @@ def reduce_op( input_: torch.Tensor, group: ProcessGroup | None, *, op: ReduceOp = ReduceOp.SUM, async_op: bool = False -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: if group: handle = all_reduce(input_, group=group, async_op=async_op, op=op) else: @@ -62,7 +61,7 @@ def swap_mult_dim(tensor: torch.Tensor, factor: int, old_dim: int, new_dim: int) def gather_op( input_: torch.Tensor, group: ProcessGroup | None, dim: int, async_op: bool = False, out=None -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: """Gather tensors and concatenate along the last dimension.""" # Bypass the function if we are using only 1 GPU. 
if not group: @@ -89,7 +88,7 @@ def reduce_scatter_op( op: ReduceOp = ReduceOp.SUM, dim: int = 0, async_op: bool = False, -) -> tuple[torch.Tensor, Work] | torch.Tensor: +) -> tuple[torch.Tensor, torch.distributed.Work] | torch.Tensor: """Reduce-scatter the input tensor across model parallel group.""" # Bypass the function if we are using only 1 GPU. if not group: diff --git a/fast_llm/data/data_loader.py b/fast_llm/data/data/data_loader.py similarity index 100% rename from fast_llm/data/data_loader.py rename to fast_llm/data/data/data_loader.py diff --git a/fast_llm/data/data/gpt/data.py b/fast_llm/data/data/gpt/data.py index e572e8e61..17f151919 100644 --- a/fast_llm/data/data/gpt/data.py +++ b/fast_llm/data/data/gpt/data.py @@ -8,8 +8,8 @@ from fast_llm.core.distributed import safe_barrier from fast_llm.data.data.abstract import Data +from fast_llm.data.data.data_loader import SampledDatasetIterator from fast_llm.data.data.gpt.config import GPTDataConfig -from fast_llm.data.data_loader import SampledDatasetIterator from fast_llm.data.dataset.abstract import SampledDataset from fast_llm.data.dataset.config import SamplingParameters from fast_llm.data.dataset.gpt.config import GPTSamplingData diff --git a/fast_llm/data/preprocessing/tokenizer.py b/fast_llm/data/preprocessing/tokenizer.py index 157744f51..4408ca772 100644 --- a/fast_llm/data/preprocessing/tokenizer.py +++ b/fast_llm/data/preprocessing/tokenizer.py @@ -11,6 +11,7 @@ if typing.TYPE_CHECKING: import numpy as np import torch + import transformers @config_class(dynamic_type={PreprocessingConfig: "tokenizer"}) @@ -52,7 +53,7 @@ def __init__(self, config: ConfigType): from transformers import AutoTokenizer log_main_rank(f"> loading tokenizer from {config.path} ...") - self.tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer: "transformers.PreTrainedTokenizer" = AutoTokenizer.from_pretrained( pretrained_model_name_or_path=self._config.path, errors="replace", max_len=None, @@ -70,10 +71,15 @@ def 
__init__(self, config: ConfigType): @functools.cached_property def vocab_size(self) -> int: - out = len(self.tokenizer) - if self._config.max_vocab_size is not None: - out = min(out, self._config.max_vocab_size) - return out + return ( + self._tokenizer_vocab_size + if self._config.max_vocab_size is None + else min(self._tokenizer_vocab_size, self._config.max_vocab_size) + ) + + @functools.cached_property + def _tokenizer_vocab_size(self) -> int: + return len(self.tokenizer) @property def vocab(self) -> dict[str, int]: @@ -99,7 +105,11 @@ def tokenize( tokens = ( torch.tensor( tokens, - dtype=torch.int64 if len(self.tokenizer) > torch.iinfo(data_type.torch).max else data_type.torch, + dtype=( + torch.int64 + if self._tokenizer_vocab_size > torch.iinfo(data_type.torch).max + else data_type.torch + ), ) % self._config.max_vocab_size ).to(data_type.torch) diff --git a/fast_llm/data/sample/patch.py b/fast_llm/data/sample/patch.py index 7ae537104..32ea60cb8 100644 --- a/fast_llm/data/sample/patch.py +++ b/fast_llm/data/sample/patch.py @@ -85,7 +85,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return PatchSample( + return self.__class__( self.patches.new_empty((0, *self.patches.shape[1:])), self.token_map.new_empty(0), self.positions.new_empty([0, self.patches.ndim - 2]), diff --git a/fast_llm/data/sample/range.py b/fast_llm/data/sample/range.py index 53683342a..f57ee04d9 100644 --- a/fast_llm/data/sample/range.py +++ b/fast_llm/data/sample/range.py @@ -52,7 +52,7 @@ def __len__(self) -> int: return self.sample_size def get_padding(self, size: int) -> typing.Self: - return RangeSample([], size) + return self.__class__([], size) class RangeBatch(Batch): diff --git a/fast_llm/data/sample/token.py b/fast_llm/data/sample/token.py index cd4d7fa02..6ab55dbba 100644 --- a/fast_llm/data/sample/token.py +++ b/fast_llm/data/sample/token.py @@ -58,7 +58,7 @@ def __len__(self) -> int: return len(self.tokens) def 
get_padding(self, size: int) -> typing.Self: - return TokenSample(torch.full([size], -100, dtype=self.tokens.dtype), [size]) + return self.__class__(torch.full([size], -100, dtype=self.tokens.dtype), [size]) class TokenBatch(Batch): diff --git a/fast_llm/engine/checkpoint/config.py b/fast_llm/engine/checkpoint/config.py index 3f1970538..98303539e 100644 --- a/fast_llm/engine/checkpoint/config.py +++ b/fast_llm/engine/checkpoint/config.py @@ -141,11 +141,12 @@ class CheckpointSaveConfigBase(CheckpointConfigBase): @config_class() class CheckpointStateSaveConfigBase(CheckpointSaveConfigBase, CheckpointStateConfigBase): + _abstract = False model_weights: bool = FieldUpdate(desc="Save the model weights.") optimizer_state: bool = FieldUpdate(desc="Save the optimizer state. Default: save if supported by the `format`.") def _validate(self) -> None: - if self.optimizer_state is None: + if self.optimizer_state is None and hasattr(self.format, "support_optimizer"): with self._set_implicit_default(): # TODO: Make sure it's a type self.optimizer_state = self.format.support_optimizer diff --git a/fast_llm/engine/multi_stage/fsdp.py b/fast_llm/engine/multi_stage/fsdp.py index ae37410ae..f84f36309 100644 --- a/fast_llm/engine/multi_stage/fsdp.py +++ b/fast_llm/engine/multi_stage/fsdp.py @@ -4,7 +4,6 @@ import typing import torch -from torch._C._distributed_c10d import ReduceOp from torch.distributed import all_reduce, reduce_scatter_tensor from fast_llm.core.distributed import ProcessGroup @@ -398,7 +397,7 @@ def reduce_gradients( out, self._grad_buffer, group=self._fsdp_group, - op=ReduceOp.AVG, + op=torch.distributed.ReduceOp.AVG, ) if accumulate: triton_add(self._grad_shard, out, self._grad_shard) diff --git a/fast_llm/engine/multi_stage/multi_stage.py b/fast_llm/engine/multi_stage/multi_stage.py index 698f62daa..ed293b103 100644 --- a/fast_llm/engine/multi_stage/multi_stage.py +++ b/fast_llm/engine/multi_stage/multi_stage.py @@ -4,7 +4,6 @@ import warnings import torch -from 
torch._C._distributed_c10d import ProcessGroup from fast_llm.config import Configurable from fast_llm.engine.base_model.base_model import BaseModel @@ -611,7 +610,7 @@ class TiedParameter: # Whether the local rank is involved at all. on_device: bool # Process group for reduction. - group: ProcessGroup | None = dataclasses.field(repr=False, init=False) + group: torch.distributed.ProcessGroup | None = dataclasses.field(repr=False, init=False) all_ranks: set[int] # The index of the main stage. main_stage: int diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index f1212f4b8..4bdf87c3f 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -1,31 +1,42 @@ import torch from fast_llm.core.distributed import ProcessGroup, ReduceOp, all_reduce -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward +from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.utils import Assert -def _torch_entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, +@torch.compile +def torch_entropy_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, entropy_loss_type: EntropyLossType, temperature: float = 1.0, -) -> tuple[torch.Tensor, torch.Tensor | None]: +) -> tuple[torch.Tensor, torch.Tensor | None]: # (), (*batch, vocab) """ A wrapper for the pytorch implementation of cross-entropy. The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. 
- TODO: loss masking only works for with labels format and if the masking index is set to -100. """ + + # Torch methods require flattened batch dimension. + target = target.flatten() if target_format == TargetFormat.labels else target.flatten(0, -2) + if target_format == TargetFormat.labels: + assert loss_mask is None + loss_mask = target >= 0 + else: + target = target.float() + if loss_mask is not None: + loss_mask = loss_mask.flatten() + # Torch compile doesn't understand this. with torch.set_grad_enabled(grad_output is not None): logits_ = logits.float().detach().requires_grad_(grad_output is not None) - logits_scaled = logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor + + logits_scaled = (logits_ if logits_scale_factor == 1.0 else logits_ * logits_scale_factor).flatten(0, -2) if target_format == TargetFormat.logits: target_scale = logits_scale_factor / temperature target = target if target_scale == 1.0 else target * target_scale @@ -35,9 +46,7 @@ def _torch_entropy_loss_forward_backward( if entropy_loss_type == EntropyLossType.cross_entropy: if target_format == TargetFormat.logits: target = torch.softmax(target, dim=-1) - loss = torch.nn.functional.cross_entropy( - logits_scaled, target, reduction="mean" if loss_mask is None else "none" - ) + loss = torch.nn.functional.cross_entropy(logits_scaled, target, reduction="none") else: predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) if target_format == TargetFormat.logits: @@ -45,30 +54,33 @@ def _torch_entropy_loss_forward_backward( elif target_format == TargetFormat.probabilities: target_log_probability = target.log() else: - target_log_probability = ( - torch.nn.functional.one_hot(target, num_classes=logits_scaled.size(-1)).add(1.0e-10).log() + target_probability = torch.nn.functional.one_hot( + torch.clamp_min(target, 0), num_classes=logits_scaled.size(-1) ) + if loss_mask is not None: + target_probability = target_probability * loss_mask.unsqueeze(-1) + 
target_log_probability = target_probability.add(1.0e-10).log() if entropy_loss_type == EntropyLossType.forward_kl: loss = torch.nn.functional.kl_div( predicted_log_probability, target_log_probability, - reduction="batchmean" if loss_mask is None else "none", + reduction="none", log_target=True, ) elif entropy_loss_type == EntropyLossType.reverse_kl: loss = torch.nn.functional.kl_div( target_log_probability, predicted_log_probability, - reduction="batchmean" if loss_mask is None else "none", + reduction="none", log_target=True, ) else: raise NotImplementedError(entropy_loss_type) - if loss_mask is not None: - loss = loss.sum(dim=-1) + loss = loss.sum(dim=-1) if loss_mask is not None: - loss = (loss * loss_mask).mean() + loss = loss * loss_mask + loss = loss.mean() if grad_output is None: grad = None @@ -79,42 +91,54 @@ def _torch_entropy_loss_forward_backward( @torch.compile -def _fused_softmax_base( - logits: torch.Tensor, logits_scale_factor: float = 1.0, group: ProcessGroup | None = None, dim: int = -1 -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: +def fused_softmax_base( + logits: torch.Tensor, # (*batch, vocab) + logits_scale_factor: float = 1.0, + group: ProcessGroup | None = None, + dim: int = -1, +) -> tuple[ + torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor +]: # (*batch, vocab), (*batch, vocab), (*batch,), (*batch,) + """ + Calculate the required inputs for softmax computation, mainly sum_exp_logits, + in a numerically stable way and with tensor-parallel support. + Warning: The returned values are regularized by `logits_max`. + The regularization typically but not always cancels out in derived quantities. 
+ """ logits = logits.float() if logits_scale_factor != 1.0: logits = logits * logits_scale_factor - logits_max = torch.max(logits, dim=dim, keepdim=True)[0] + logits_max = logits.max(dim=dim)[0] if group is not None: all_reduce(logits_max, op=ReduceOp.MAX, group=group) - logits_norm = (logits - logits_max).float() + logits_norm = (logits - logits_max.unsqueeze(-1)).float() exp_logits = logits_norm.exp() - sum_exp_logits = exp_logits.sum(dim=dim, keepdim=True) + sum_exp_logits = exp_logits.sum(dim=dim) if group is not None: all_reduce(sum_exp_logits, op=ReduceOp.SUM, group=group) - return logits_norm, exp_logits, sum_exp_logits + return logits_norm, exp_logits, sum_exp_logits, logits_max @torch.compile def _fused_reverse_kl_base( - logits: torch.Tensor, - target: torch.Tensor, + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch, vocab) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, temperature: float = 1.0, -): - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) - predicted_log_probability = logits_norm - sum_exp_logits.log() - predicted_probability = exp_logits / sum_exp_logits +) -> tuple[torch.Tensor, torch.Tensor | None]: # (*batch,), (*batch, vocab) + assert target_format in (TargetFormat.logits, TargetFormat.probabilities) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_log_probability = logits_norm - sum_exp_logits.log().unsqueeze(-1) + predicted_probability = exp_logits / sum_exp_logits.unsqueeze(-1) if target_format == TargetFormat.logits: - target_logits_norm, _, sum_exp_target_logits = _fused_softmax_base( + target_logits_norm, _, sum_exp_target_logits, _ = fused_softmax_base( target, logits_scale_factor / temperature, group ) - target_log_probability = target_logits_norm - sum_exp_target_logits.log() + target_log_probability = target_logits_norm - 
sum_exp_target_logits.log().unsqueeze(-1) else: target_log_probability = torch.log(target) @@ -137,32 +161,32 @@ def _fused_reverse_kl_base( @torch.compile def _fused_cross_entropy_base( - logits: torch.Tensor, - target: torch.Tensor, + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch, vocab) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, group: ProcessGroup | None = None, temperature: float = 1.0, return_kl_loss: bool = False, -): - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) +) -> tuple[torch.Tensor, torch.Tensor | None]: # (*batch,), (*batch, vocab) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) if target_format == TargetFormat.logits: - target_logits_norm, exp_logits_targets, sum_exp_target_logits = _fused_softmax_base( + target_logits_norm, exp_logits_targets, sum_exp_target_logits, _ = fused_softmax_base( target, logits_scale_factor / temperature, group ) - target = exp_logits_targets / sum_exp_target_logits + target = exp_logits_targets / sum_exp_target_logits.unsqueeze(-1) # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) if return_kl_loss: if target_format == TargetFormat.logits: - target_log_probability = target_logits_norm - sum_exp_target_logits.log() + target_log_probability = target_logits_norm - sum_exp_target_logits.log().unsqueeze(-1) else: target_log_probability = torch.log(target) logits_norm = logits_norm - target_log_probability - predicted_logits = (target * logits_norm).sum(dim=-1, keepdim=True) + predicted_logits = (target * logits_norm).sum(dim=-1) if group is not None: # We need to sum the over the tensor-parallel group, # but this is handled in the final averaging provided we multiply by the group size. 
@@ -174,41 +198,63 @@ def _fused_cross_entropy_base( grad = None else: # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - grad = (exp_logits - sum_exp_logits * target) * (grad_output / sum_exp_logits) + grad = (exp_logits - sum_exp_logits.unsqueeze(-1) * target) * (grad_output / sum_exp_logits.unsqueeze(-1)) return per_sample_loss, grad @torch.compile -def _fused_cross_entropy_base_from_labels( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor, - grad_output: float | None, - logits_scale_factor: float, +def fused_predicted_logits_from_labels( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + loss_mask: torch.Tensor, # (*batch,), == target>=0 group: ProcessGroup | None = None, -): - logits_norm, exp_logits, sum_exp_logits = _fused_softmax_base(logits, logits_scale_factor, group) +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # (*batch,), (*batch,), (*batch,) + """ + Recover the value of the logits at the target index, with support for masking (target < 0) and tensor parallelism. + In the simple case, equivalent to `logits.gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1)` - target = target.unsqueeze(-1) + Normally used in combination with `fused_softmax_base`, may also recover probabilities or log probabilities: + `predicted_probabilities = predicted_logits.exp() / sum_exp_logits` + `predicted_log_probabilities = predicted_logits / sum_exp_logits.log()` + """ if group is None: # Keep values within range for scatter and gather ops to work. - target = target * loss_mask.unsqueeze(-1) + target_masked = target * loss_mask target_mask = None else: - # Mask the target (fused) + # Mask the target (fused). # TODO: Could mask earlier on cpu or overlap with reduce? 
vocab_start_index = logits.size(-1) * group.rank() target_mask = (target >= vocab_start_index) * (target < vocab_start_index + logits.size(-1)) - target = (target - vocab_start_index) * target_mask + target_masked = (target - vocab_start_index) * target_mask # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) # KL loss is the same because P * log(P) == 0. - predicted_logits = logits_norm.gather(1, target) + predicted_logits = logits.gather(-1, target_masked.unsqueeze(-1)).squeeze(-1) if group is not None: predicted_logits = target_mask * predicted_logits all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) + return predicted_logits, target_masked, target_mask + + +@torch.compile +def _fused_cross_entropy_base_from_labels( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) + loss_mask: torch.Tensor, # (*batch,) + grad_output: float | None, + logits_scale_factor: float, + group: ProcessGroup | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None]: # (*batch,), (*batch, vocab) + logits_norm, exp_logits, sum_exp_logits, _ = fused_softmax_base(logits, logits_scale_factor, group) + predicted_logits, target_masked, target_mask = fused_predicted_logits_from_labels( + logits_norm, target, loss_mask, group + ) + + # CE loss = mean(log(sum_exp_logits) - sum(probabilities * logits)) + # KL loss is the same because P * log(P) == 0. per_sample_loss = sum_exp_logits.log() - predicted_logits if grad_output is None: @@ -216,17 +262,19 @@ def _fused_cross_entropy_base_from_labels( else: # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. 
grad = exp_logits.scatter_add( - 1, target, -sum_exp_logits if target_mask is None else -(target_mask * sum_exp_logits) - ) * (grad_output / sum_exp_logits) + -1, + target_masked.unsqueeze(-1), + -sum_exp_logits.unsqueeze(-1) if target_mask is None else -(target_mask * sum_exp_logits).unsqueeze(-1), + ) * (grad_output / sum_exp_logits.unsqueeze(-1)) return per_sample_loss, grad @torch.compile -def _fused_entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, +def fused_entropy_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) grad_output: float | None, logits_scale_factor: float, target_format: TargetFormat, @@ -239,11 +287,11 @@ def _fused_entropy_loss_forward_backward( It is an improvement over the pytorch implementation because of the fused casting, both in speed and memory, but still suboptimal because it needs multiple kernels. 
""" - grad_output = None if grad_output is None else grad_output / logits.size(0) * logits_scale_factor + grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor if target_format == TargetFormat.labels: assert entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl) - if loss_mask is None: - loss_mask = target >= 0 + assert loss_mask is None + loss_mask = target >= 0 per_sample_loss, grad = _fused_cross_entropy_base_from_labels( logits, target, @@ -277,7 +325,7 @@ def _fused_entropy_loss_forward_backward( raise NotImplementedError(entropy_loss_type) if loss_mask is not None: - per_sample_loss = per_sample_loss * loss_mask.unsqueeze(-1) + per_sample_loss = per_sample_loss * loss_mask loss = per_sample_loss.mean() if grad is not None: @@ -286,63 +334,3 @@ def _fused_entropy_loss_forward_backward( grad = grad.to(logits.dtype) return loss, grad - - -_ENTROPY_LOSS_IMPLEMENTATIONS = { - EntropyLossImplementation.torch: _torch_entropy_loss_forward_backward, - EntropyLossImplementation.fused: _fused_entropy_loss_forward_backward, - EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, -} - - -def entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - group: ProcessGroup | None = None, - implementation: EntropyLossImplementation = EntropyLossImplementation.fused, - logits_scale_factor: float = 1.0, - temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, - entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Select the appropriate implementation of cross-entropy. - The triton implementation from the triton submodule is the fastest and recommended one. 
- It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, - which is faster and has a relatively small memory overhead. - """ - if target_format == TargetFormat.labels: - Assert.eq(target.shape, logits.shape[:-1]) - Assert.eq(target.dtype, torch.int64) - assert loss_mask is None - else: - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - if group: - Assert.eq(implementation, EntropyLossImplementation.fused) - return _fused_entropy_loss_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - group, - temperature, - ) - else: - return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - temperature=temperature, - ) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 709d0c52d..ef2039ade 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -140,7 +140,8 @@ def triton_cross_entropy_forward_backward( # TODO: Improve assumptions. 
assert logits.is_contiguous() assert target.is_contiguous() - n_rows, n_cols = logits.shape + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) block_size = triton.next_power_of_2(n_cols) assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) @@ -155,8 +156,8 @@ def triton_cross_entropy_forward_backward( losses, None if grad_output is None else grad_output / n_rows, n_cols, - logits.stride(0), - None if grad_output is None else grad_logits.stride(0), + logits.stride(-2), + None if grad_output is None else grad_logits.stride(-2), logits_scale_factor, block_size=block_size, num_warps=num_warps, @@ -172,9 +173,9 @@ def triton_cross_entropy_forward_backward( losses, None if grad_output is None else grad_output / n_rows, n_cols, - logits.stride(0), - target.stride(0), - None if grad_output is None else grad_logits.stride(0), + logits.stride(-2), + target.stride(-2), + None if grad_output is None else grad_logits.stride(-2), logits_scale_factor, block_size=block_size, num_warps=num_warps, diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py index 15c4c788c..2194c6f86 100644 --- a/fast_llm/layers/language_model/loss/dpo.py +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -49,17 +49,18 @@ def dpo_loss( beta: float = 1.0, logits_scale_factor: float = 1.0, ) -> torch.Tensor: + logits = logits.float() if logits_scale_factor != 1.0: # TODO: Make more efficient. 
logits = logits * logits_scale_factor - policy_log_probabilities = _get_target_log_probabilities(logits, targets) + policy_log_probabilities = get_target_log_probabilities(logits, targets) policy_log_ratios = _get_target_log_probability_for_spans( policy_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(policy_log_probabilities, rejected_spans) - reference_log_probabilities = _get_target_log_probabilities(reference_model_logits.float().detach(), targets) + reference_log_probabilities = get_target_log_probabilities(reference_model_logits.float().detach(), targets) reference_log_ratios = _get_target_log_probability_for_spans( reference_log_probabilities, chosen_spans ) - _get_target_log_probability_for_spans(reference_log_probabilities, rejected_spans) @@ -68,14 +69,17 @@ def dpo_loss( return -torch.nn.functional.logsigmoid(beta * (policy_log_ratios - reference_log_ratios)).mean() -def _get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor): - # Gather log probabilities corresponding to the target tokens - return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - - def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): return sum( log_probabilities[sample_index, begin:end].sum() for sample_index, sample_spans in enumerate(spans) for begin, end in sample_spans ) + + +@torch.compile +def get_target_log_probabilities(logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + # Avoid negative (masked) labels. 
+ targets = targets * (targets >= 0) + # Gather log probabilities corresponding to the target tokens + return torch.nn.functional.log_softmax(logits, dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 3ae87d2e9..1dfd3920c 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -3,15 +3,17 @@ import torch from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.entropy_loss import entropy_loss_forward_backward +from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward +from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward from fast_llm.layers.language_model.loss.config import ( LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss +from fast_llm.utils import Assert -def _get_imlementation( +def _get_implementation( default: EntropyLossImplementation = EntropyLossImplementation.auto, loss_type: EntropyLossType = EntropyLossType.cross_entropy, vocab_parallel: bool = False, @@ -34,7 +36,7 @@ def _get_imlementation( class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._implementation = _get_imlementation( + self._implementation = _get_implementation( self._config.implementation, self._config.loss_type, self._vocab_parallel ) @@ -63,7 +65,7 @@ def __init__(self, *args, **kwargs): if self._prediction_distance > 0: raise NotImplementedError() - self._implementation = _get_imlementation( + self._implementation = _get_implementation( self._config.implementation, 
self._config.loss_type, self._vocab_parallel ) @@ -84,3 +86,63 @@ def forward_backward( target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, ) + + +_ENTROPY_LOSS_IMPLEMENTATIONS = { + EntropyLossImplementation.torch: torch_entropy_loss_forward_backward, + EntropyLossImplementation.fused: fused_entropy_loss_forward_backward, + EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, +} + + +def entropy_loss_forward_backward( + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + implementation: EntropyLossImplementation = EntropyLossImplementation.fused, + logits_scale_factor: float = 1.0, + temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Select the appropriate implementation of cross-entropy. + The triton implementation from the triton submodule is the fastest and recommended one. + It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, + which is faster and has a relatively small memory overhead. 
+ """ + if target_format == TargetFormat.labels: + Assert.eq(target.shape, logits.shape[:-1]) + Assert.eq(target.dtype, torch.int64) + assert loss_mask is None + else: + Assert.eq(target.shape, logits.shape) + assert target.dtype.is_floating_point, target.dtype + if loss_mask is not None: + Assert.eq(loss_mask.shape, logits.shape[:-1]) + if group: + Assert.eq(implementation, EntropyLossImplementation.fused) + return fused_entropy_loss_forward_backward( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + group, + temperature, + ) + else: + return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + temperature=temperature, + ) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 711560a8f..41e8942ac 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -116,6 +116,6 @@ def loss_forward_backward( grad = None else: loss.backward(torch.full_like(loss, grad_output)) - grad = input_.grad.detach().to(input_.dtype) + grad = input_.grad.detach() return loss, grad diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index c94851bf2..82b8d5318 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -2,8 +2,9 @@ import torch +from fast_llm.functional.entropy_loss import fused_softmax_base from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig -from fast_llm.layers.language_model.loss.loss import LanguageModelLoss, loss_forward_backward +from fast_llm.layers.language_model.loss.loss import LanguageModelLoss class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]): @@ -19,12 +20,12 @@ def forward_backward( kwargs: dict[str, typing.Any], 
split_index: int = 0, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return loss_forward_backward( - self._get_grad_output(kwargs), - z_loss, + return z_loss_forward_backward( logits, self._get_loss_mask(kwargs, split_index), - self._logits_scale_factor, + grad_output=self._get_grad_output(kwargs), + group=self._parallel_dim.group if self._vocab_parallel else None, + logits_scale_factor=self._logits_scale_factor, ) @@ -34,10 +35,41 @@ def z_loss( loss_mask: "torch.Tensor | None" = None, logits_scale_factor: float = 1.0, ) -> torch.Tensor: - """ - Z-loss = mean(logsumexp(logits, dim=-1) ** 2) - """ + # TODO: Replace usage in MoE, move to testing. + logits = logits.float() out = torch.logsumexp(logits if logits_scale_factor == 1.0 else logits * logits_scale_factor, dim=-1) ** 2 if loss_mask is not None: out = out * loss_mask return torch.mean(out) + + +@torch.compile +def z_loss_forward_backward( + logits: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + logits_scale_factor: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + Z-loss = mean(logsumexp(logits, dim=-1) ** 2) + Grad = 2 * log_sum_exp_logits * softmax(logits) + """ + grad_output = None if grad_output is None else grad_output / logits.shape[:-1].numel() * logits_scale_factor + logits_norm, exp_logits, sum_exp_logits, logits_max = fused_softmax_base(logits, logits_scale_factor, group) + log_sum_exp_logits = sum_exp_logits.log() + logits_max + + per_sample_loss = log_sum_exp_logits**2 + if loss_mask is not None: + per_sample_loss = per_sample_loss * loss_mask + loss = per_sample_loss.mean() + + if grad_output is None: + grad = None + else: + grad_base = 2 * grad_output * (log_sum_exp_logits / sum_exp_logits) + if loss_mask is not None: + grad_base = grad_base * loss_mask + grad = (grad_base.unsqueeze(-1) * exp_logits).to(logits.dtype) + + return loss, grad diff --git a/fast_llm/models/gpt/config.py 
b/fast_llm/models/gpt/config.py index a315beecc..314741c3b 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -54,6 +54,11 @@ class GPTBatchConfig(BatchConfig): desc="Read loss masking spans from the dataset.", hint=FieldHint.feature, ) + use_preference_spans: bool = Field( + default=False, + desc="Read dpo data (chosen and rejected spans) from the dataset.", + hint=FieldHint.feature, + ) truncate_documents: bool | None = Field( default=True, desc=( diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index 768d3fdd7..ded0f81c8 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -33,11 +33,11 @@ def _get_sampling_parameters( def _get_preprocessing_config( self, *, _return_dict: bool = False ) -> LanguageModelPreprocessingConfig | dict[str, typing.Any]: + out = { "type": "language_model", "vocab_size": self._config.model.base_model.embeddings.vocab_size, "use_loss_masking_spans": self._config.batch.use_loss_masking_spans, - # OK since DPO is not supported for MTP. 
- "use_preference_spans": getattr(self._config.model.base_model.head, "enable_dpo", False), + "use_preference_spans": self._config.batch.use_preference_spans, } return out if _return_dict else LanguageModelPreprocessingConfig.from_dict(out) diff --git a/tests/functional/test_entropy_loss.py b/tests/functional/test_entropy_loss.py deleted file mode 100644 index 35d1ef648..000000000 --- a/tests/functional/test_entropy_loss.py +++ /dev/null @@ -1,178 +0,0 @@ -import pathlib - -import pytest -import torch - -from fast_llm.engine.distributed.config import DistributedBackend -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.entropy_loss import entropy_loss_forward_backward -from fast_llm.utils import Assert -from tests.utils.subtest import DistributedTestContext - - -def _get_cross_entropy_inputs( - num_columns: int, loss_masking: bool, target_format: TargetFormat -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - device = "cuda" if torch.cuda.is_available() else "cpu" - # We want something moderately close to the target for the test to be meaningful - logits_var = torch.randn(256, num_columns, dtype=torch.float32, device=device) / 3 - loss_mask = torch.randint(0, 2, (256,), dtype=torch.bool, device=device) if loss_masking else None - if target_format == TargetFormat.labels: - target = torch.randint(0, num_columns, (256,), dtype=torch.int64, device=device) - logits = torch.nn.functional.one_hot(target, num_columns) + logits_var - if loss_masking: - logits = torch.where(loss_mask.unsqueeze(-1), logits, -100) - loss_mask = None - else: - target = torch.randn(256, num_columns, dtype=torch.float32, device=device) - logits = target + logits_var - if target_format == TargetFormat.probabilities: - target = torch.softmax(target, -1) - return logits, target, loss_mask - - -def _compare_entropy_loss_outputs( - loss: torch.Tensor, - ref_loss: torch.Tensor, - has_grad: bool, - grad: 
torch.Tensor | None, - ref_grad: torch.Tensor | None, - threshold=1e-5, - loss_min_threshold=1e-6, -): - Assert.rms_close_relative(loss, ref_loss, threshold, loss_min_threshold) - if has_grad: - Assert.rms_close_relative(grad, ref_grad, threshold, 1e-8) - else: - assert grad is None - assert ref_grad is None - - -@pytest.mark.slow -@pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking"), - ( - (8192, 1.0, 1.0, False), # Simple - (5000, 1.0, 1.0, False), # Not a power of 2 - (5000, None, 1.0, False), # No grad - (5000, 1.0, 4.0, False), # Loss scaling - (5000, 4.0, 1.0, False), # Grad scaling - (5000, 1.0, 1.0, True), # Loss masking - (65536, 1.0, 1.0, False), # Max block size - (65537, 1.0, 1.0, False), # Above max block size - ), -) -@pytest.mark.parametrize("target_format", TargetFormat) -@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) -def test_entropy_loss(num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type): - if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="Not implemented") - # TODO: Test tensor-parallel implementation. - logits, target, loss_mask = _get_cross_entropy_inputs(num_columns, loss_masking, target_format) - kwargs = { - "logits": logits, - "target": target, - "loss_mask": loss_mask, - "grad_output": grad_output, - "logits_scale_factor": logits_scale_factor, - "target_format": target_format, - "entropy_loss_type": entropy_loss_type, - } - # Torch serves as the reference implementation. 
- out_torch, grad_torch = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.torch) - out_fused, grad_fused = entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.fused) - - _compare_entropy_loss_outputs( - out_fused, - out_torch, - grad_output is not None, - grad_fused, - grad_torch, - loss_min_threshold=5e-6, - ) - - if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available(): - # Triton implementation only supports cross-entropy. - return - assert TritonConfig.TRITON_ENABLED - if num_columns > 65536: - with pytest.raises(AssertionError): - entropy_loss_forward_backward(**kwargs, implementation=EntropyLossImplementation.triton) - else: - out_triton, grad_triton = entropy_loss_forward_backward( - **kwargs, implementation=EntropyLossImplementation.triton - ) - _compare_entropy_loss_outputs(out_triton, out_torch, grad_output is not None, grad_triton, grad_torch) - - -def _entropy_loss_distributed( - target_format: TargetFormat, - entropy_loss_type: EntropyLossType, - loss_masking: bool, - group: torch.distributed.ProcessGroup, -): - # Ensure all workers have the same inputs. 
- torch.manual_seed(0) - rank = group.rank() - world_size = group.size() - logits, target, loss_mask = _get_cross_entropy_inputs(1000, loss_masking, target_format) - - kwargs = { - "loss_mask": loss_mask, - "grad_output": 1.0, - "target_format": target_format, - "implementation": EntropyLossImplementation.fused, - "entropy_loss_type": entropy_loss_type, - } - out_ref, grad_ref = entropy_loss_forward_backward(logits, target, **kwargs) - - out, grad = entropy_loss_forward_backward( - logits.chunk(world_size, 1)[rank], - target if target_format == TargetFormat.labels else target.chunk(world_size, 1)[rank], - group=group, - **kwargs, - ) - _compare_entropy_loss_outputs(out, out_ref, True, grad, grad_ref.chunk(world_size, 1)[rank], 1e-4) - - -def _run_entropy_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path): - for entropy_loss_type in EntropyLossType: - for target_format in TargetFormat: - if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - continue - for loss_masking in [False, True]: - name = f"{entropy_loss_type}_{target_format}_{loss_masking}" - with test_context.subtest(base_path, name, 2) as subtest: - if subtest.do_run: - _entropy_loss_distributed(target_format, entropy_loss_type, loss_masking, test_context.group) - - -@pytest.mark.slow -def test_entropy_loss_distributed_dependency(): - # Mock test so the distributed subtest are placed in the same dependency group. - pass - - -@pytest.mark.slow -@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) -def test_run_entropy_loss_distributed(run_parallel_script, result_path): - run_parallel_script( - _run_entropy_loss_distributed, - (result_path / "test_entropy_loss",), - world_size=2, - backend=DistributedBackend.gloo, - use_cuda=False, # Disable device count check. - ) - - -# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in cas of failure. 
-# This should still run after `test_run_entropy_loss_distributed` -@pytest.mark.slow -@pytest.mark.depends_on(on=["test_entropy_loss_distributed_dependency"]) -@pytest.mark.parametrize("target_format", TargetFormat) -@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) -@pytest.mark.parametrize("loss_masking", (False, True)) -def test_entropy_loss_distributed(result_path, report_subtest, target_format, entropy_loss_type, loss_masking): - if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="Not implemented") - report_subtest(result_path / f"test_entropy_loss/{entropy_loss_type}_{target_format}_{loss_masking}", 2) diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 840e3846d..6471a516f 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -1,13 +1,10 @@ -import numpy as np import pytest import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.utils import Assert -from tests.utils.dataset import get_random_spans def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans: list[list[tuple[int, int]]]): @@ -18,59 +15,6 @@ def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans ) -def reference_dpo_loss( - logits: torch.Tensor, - targets: torch.Tensor, - reference_model_logits: torch.Tensor, - chosen_spans: torch.Tensor, - rejected_spans: torch.Tensor, - beta: float, -) -> torch.Tensor: - # TODO: Too similar to the actual implementation. 
- policy_log_probs = ( - torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=targets.unsqueeze(-1)).squeeze(-1) - ) - policy_chosen_logps = sum( - policy_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(chosen_spans) - for begin, end in sample_spans - ) - policy_rejected_logps = sum( - policy_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(rejected_spans) - for begin, end in sample_spans - ) - reference_log_probs = ( - torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) - .gather(dim=-1, index=targets.unsqueeze(-1)) - .squeeze(-1) - ) - reference_chosen_logps = sum( - reference_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(chosen_spans) - for begin, end in sample_spans - ) - reference_rejected_logps = sum( - reference_log_probs[sample_index, begin:end].sum() - for sample_index, sample_spans in enumerate(rejected_spans) - for begin, end in sample_spans - ) - pi_logratios = policy_chosen_logps - policy_rejected_logps - ref_logratios = reference_chosen_logps - reference_rejected_logps - return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() - - -def test_dpo_loss(): - logits = torch.normal(0, 1, (10, 50, 100)) - reference_model_logits = torch.normal(0, 1, (10, 50, 100)) - targets = torch.randint(0, 100, (10, 50)) - spans = get_random_spans(np.full(10, 50), 0, 10) - - fastllm_loss = dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2]) - reference_loss = reference_dpo_loss(logits, targets, reference_model_logits, spans[::2], spans[1::2], beta=1) - Assert.rms_close(fastllm_loss, reference_loss, 1e-5) - - @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py 
new file mode 100644 index 000000000..38c69b98e --- /dev/null +++ b/tests/layers/test_lm_losses.py @@ -0,0 +1,346 @@ +import contextlib +import pathlib +import random + +import numpy as np +import pytest +import torch + +from fast_llm.core.ops import split_op +from fast_llm.engine.config_utils import data_type +from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.distributed.config import DistributedBackend +from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.layers.language_model.loss.dpo import dpo_loss +from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward +from fast_llm.layers.language_model.loss.loss import loss_forward_backward +from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward +from fast_llm.utils import Assert +from tests.utils.dataset import get_random_spans +from tests.utils.subtest import DistributedTestContext + +VOCAB_SIZE = 100 +NUM_TOKENS = 200 + + +def _get_lm_loss_inputs( + num_columns: int, loss_masking: bool, target_format: TargetFormat, batch_shape: tuple[int], dtype +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + device = "cuda" if torch.cuda.is_available() else "cpu" + # We want something moderately close to the target for the test to be meaningful + logits_var = torch.randn((*batch_shape, num_columns), dtype=dtype.torch, device=device) / 3 + loss_mask = torch.randint(0, 2, batch_shape, dtype=torch.bool, device=device) if loss_masking else None + if target_format == TargetFormat.labels: + target = torch.randint(0, num_columns, batch_shape, dtype=torch.int64, device=device) + logits = torch.nn.functional.one_hot(target, num_columns) + logits_var + if loss_masking: + target = torch.where(loss_mask, target, -100) + loss_mask = None + else: + # Target logits are typically in training precision, ex. with distillation model. 
+ target = torch.randn((*batch_shape, num_columns), dtype=dtype.torch, device=device) + logits = target + logits_var + if target_format == TargetFormat.probabilities: + # Probabilities need to be in full precision for accuracy. + target = torch.softmax(target, -1, dtype=torch.float32) + return logits, target, loss_mask + + +def _compare_losses_and_grads( + loss: torch.Tensor, + ref_loss: torch.Tensor, + has_grad: bool, + grad: torch.Tensor | None, + ref_grad: torch.Tensor | None, + threshold=1e-5, + group: torch.distributed.ProcessGroup | None = None, +): + Assert.rms_close_relative(loss, ref_loss, threshold, 1e-6) + if has_grad: + Assert.rms_close_relative( + grad, split_op(ref_grad, group, -1), threshold, 1e-8 if grad.dtype == torch.float32 else 1e-7 + ) + else: + assert grad is None + assert ref_grad is None + + +def reference_dpo_loss( + logits: torch.Tensor, + labels: torch.Tensor, + reference_model_logits: torch.Tensor, + chosen_spans: torch.Tensor, + rejected_spans: torch.Tensor, + beta: float, +) -> torch.Tensor: + # TODO: Too similar to the actual implementation. 
+ policy_log_probs = ( + torch.nn.functional.log_softmax(logits.float(), dim=-1).gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + ) + policy_chosen_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + policy_rejected_logps = sum( + policy_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + reference_log_probs = ( + torch.nn.functional.log_softmax(reference_model_logits.float(), dim=-1) + .gather(dim=-1, index=labels.unsqueeze(-1)) + .squeeze(-1) + ) + reference_chosen_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(chosen_spans) + for begin, end in sample_spans + ) + reference_rejected_logps = sum( + reference_log_probs[sample_index, begin:end].sum() + for sample_index, sample_spans in enumerate(rejected_spans) + for begin, end in sample_spans + ) + pi_logratios = policy_chosen_logps - policy_rejected_logps + ref_logratios = reference_chosen_logps - reference_rejected_logps + return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() + + +_BATCH_SHAPES = ((64,), (16, 8)) +_LOSS_PARAMETERS = ( + (500, 1.0, 1.0, False, DataType.float32), # Simple + (512, 1.0, 1.0, False, DataType.float32), # Power of 2 + (500, None, 1.0, False, DataType.float32), # No grad + (500, 1.0, 4.0, False, DataType.float32), # Loss scaling + (500, 4.0, 1.0, False, DataType.float32), # Grad scaling + (500, 1.0, 1.0, True, DataType.float32), # Loss masking + (500, 1.0, 1.0, True, DataType.float16), # Fp16 + (500, 1.0, 1.0, True, DataType.bfloat16), # Bf16 + (65538, 1.0, 1.0, False, DataType.float32), # Above max block size +) + + +def _test_entropy_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + group=None, +): + 
if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + pytest.skip(reason="Not implemented") + # TODO: Test tensor-parallel implementation. + logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, target_format, batch_shape, dtype) + # Torch serves as the reference implementation. + out_ref, grad_ref = entropy_loss_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.torch, + ) + out_fused, grad_fused = entropy_loss_forward_backward( + logits=split_op(logits, group, -1), + target=target if target_format == TargetFormat.labels else split_op(target, group, -1), + loss_mask=loss_mask, + grad_output=grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.fused, + ) + + _compare_losses_and_grads( + out_fused, + out_ref, + grad_output is not None, + grad_fused, + grad_ref, + threshold=1e-5 if data_type == DataType.float32 else 1e-4, + group=group, + ) + + if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available() or group is not None: + # Triton implementation only supports cross-entropy. 
+ return + assert TritonConfig.TRITON_ENABLED + with pytest.raises(AssertionError) if num_columns > 65536 else contextlib.nullcontext(): + out_triton, grad_triton = entropy_loss_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.triton, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) + + +def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): + logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.logits, batch_shape, dtype) + out_ref, grad_ref = loss_forward_backward( + grad_output, + z_loss, + logits, + loss_mask, + logits_scale_factor=logits_scale_factor, + ) + out_fused, grad_fused = z_loss_forward_backward( + split_op(logits, group, -1), + loss_mask, + grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + ) + _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + + +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS +) +@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("entropy_loss_type", EntropyLossType) +def test_entropy_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type, dtype +): + _test_entropy_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + ) + + +@pytest.mark.slow +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", 
"loss_masking", "dtype"), _LOSS_PARAMETERS +) +def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype): + _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) + + +@pytest.mark.skip(reason="DPO loss is broken") +def test_dpo_loss(): + logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) + reference_model_logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) + labels = torch.randint(0, VOCAB_SIZE, (NUM_TOKENS,)) + spans = get_random_spans(np.full(10, 50), 0, 10) + + fast_llm_loss = dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2]) + reference_loss = reference_dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2], beta=1) + Assert.rms_close(fast_llm_loss, reference_loss, 1e-5) + + +def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path, seed: int): + for batch_shape in _BATCH_SHAPES: + for num_columns, grad_output, logits_scale_factor, loss_masking, dtype in _LOSS_PARAMETERS: + suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}" + # Entropy loss + for entropy_loss_type in EntropyLossType: + for target_format in TargetFormat: + if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: + continue + with test_context.subtest( + base_path, f"{entropy_loss_type}-{target_format}-{suffix}", 2 + ) as subtest: + if subtest.do_run: + torch.manual_seed((seed + hash(subtest.name)) % 2**32) + _test_entropy_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + test_context.group, + ) + # Z loss + with test_context.subtest(base_path, f"z_loss-{suffix}", 2) as subtest: + if subtest.do_run: + torch.manual_seed((seed + hash(subtest.name)) % 2**32) + _test_z_loss( + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + 
loss_masking, + dtype, + test_context.group, + ) + + +@pytest.mark.slow +def test_lm_loss_distributed_dependency(): + # Mock test so the distributed subtest are placed in the same dependency group. + pass + + +# We don't want to depend on `test_run_entropy_loss_distributed` because we still want to run this in cas of failure. +# This should still run after `test_run_entropy_loss_distributed` +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) +def test_run_lm_loss_distributed(run_parallel_script, result_path): + run_parallel_script( + _run_lm_loss_distributed, + (result_path / "test_losses", random.randint(0, 2**32 - 1)), + world_size=2, + backend=DistributedBackend.gloo, + use_cuda=False, # Disable device count check. + ) + + +@pytest.mark.slow +@pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) +@pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) +@pytest.mark.parametrize( + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS +) +@pytest.mark.parametrize( + "loss_type", + ( + *( + f"{entropy_loss_type}-{target_format}" + for entropy_loss_type in EntropyLossType + for target_format in TargetFormat + if target_format != TargetFormat.labels or entropy_loss_type != EntropyLossType.reverse_kl + ), + "z_loss", + ), +) +def test_lm_loss_distributed( + result_path, + report_subtest, + loss_type, + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, +): + report_subtest( + result_path + / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}", + 2, + use_cuda=False, + ) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 955fa534c..4741699be 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -479,10 +479,6 @@ def test_load_parallel_checkpoint_in_single_gpu( 
distributed_save_load_config = distributed_save_load_config.resolve( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) - if torch.cuda.device_count() < distributed_save_load_config.num_gpus: - pytest.skip( - f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}" - ) report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus) load_and_compare_checkpoints( DistributedCheckpointFormat, diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 0c58afade..df3e52b7a 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -94,8 +94,6 @@ def test_model_distributed( config = DISTRIBUTED_TESTING_CONFIGS[config_name] if model_testing_config.should_skip(config): pytest.skip(f"Configuration not supported.") - if torch.cuda.device_count() < config.num_gpus: - pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < {config.num_gpus}") report_subtest(run_test_script_base_path / config.name, config.num_gpus) if config.compare is not None: if not check_subtest_success(run_test_script_base_path / config.compare): diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 854ecec36..4ad122947 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -322,9 +322,10 @@ def get_model_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset", seed=1234, + num_documents=200, max_loss_masking_spans=5, max_vocab_size=MODEL_TEST_VOCAB_SIZE, - splits={"training": 969, "validation": 30, "test": 1}, + splits={"training": 180, "validation": 19, "test": 1}, config_only=config_only, ) @@ -333,6 +334,7 @@ def get_multimodal_test_dataset(config_only: bool = False): return _get_test_dataset( DATASET_CACHE / "model_dataset_multimodal", seed=1234, + num_documents=200, max_vocab_size=MODEL_TEST_VOCAB_SIZE, max_images=2, image_patch_config=ImagePatchConfig( @@ -343,6 +345,6 @@ def 
get_multimodal_test_dataset(config_only: bool = False): image_break_token=None, image_end_token=None, ), - splits={"training": 969, "validation": 30, "test": 1}, + splits={"training": 180, "validation": 19, "test": 1}, config_only=config_only, ) diff --git a/tests/utils/subtest.py b/tests/utils/subtest.py index b6764c0e2..b8f0b5b7a 100644 --- a/tests/utils/subtest.py +++ b/tests/utils/subtest.py @@ -51,12 +51,12 @@ def __enter__(self): self._configure_logging() self._group = self._pool.get_process_group(range(self._world_size), self._rank) # TODO: Barriers needed? - safe_barrier(self._group, "start", device=self._pool.device) + safe_barrier(self._group, "start") return self def __exit__(self, exc_type, exc_val, exc_tb): # Final barrier to ensure everything is done before torchrun potentially kills workers. - safe_barrier(self._group, "testing end", device=self._pool.device) + safe_barrier(self._group, "testing end") # Let pytest know how things went. # These should already be reported above, we repeat for convenience. if self._failures: @@ -138,13 +138,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): if (group := self._test_context._group) is not None: # Barrier so `allreduce_scalar` doesn't go crazy in case of desync. - safe_barrier(group, self._name, device=self._test_context._pool.device) - self._success = ( - allreduce_scalar( - self._success, dtype=torch.int64, group=group, device=self._test_context._pool.device - ) - == group.size() - ) + safe_barrier(group, self._name) + self._success = allreduce_scalar(self._success, dtype=torch.int64, group=group) == group.size() if self._do_capture and torch.cuda.is_available(): # Free resources to limit memory usage. 
@@ -165,6 +160,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): def do_run(self) -> bool: return self._do_run and not self._skip + @property + def name(self) -> str: + return self._name + def set_subtest_success(path: pathlib.Path, success: bool = True): path.joinpath("pytest_success").write_text(str(int(success))) @@ -201,7 +200,9 @@ def report_subtest(request: pytest.FixtureRequest): verbose = request.config.getoption("verbose") do_capture = request.config.getoption("distributed_capture") - def do_report_subtest(path: pathlib.Path, world_size: int) -> None: + def do_report_subtest(path: pathlib.Path, world_size: int, use_cuda: bool = True) -> None: + if use_cuda and torch.cuda.device_count() < world_size: + pytest.skip(f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {world_size}") success = check_subtest_success(path) if not do_capture: logger.warning("Distributed capture is disabled. See distributed test for run output.") From 741832cef01b28fb7e75bf3545b44f1ec7dd8789 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 29 Jan 2026 16:22:44 -0500 Subject: [PATCH 02/22] fix --- tests/models/test_checkpoint.py | 4 ++++ tests/models/test_model.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 4741699be..955fa534c 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -479,6 +479,10 @@ def test_load_parallel_checkpoint_in_single_gpu( distributed_save_load_config = distributed_save_load_config.resolve( base_path=run_test_script_base_path, model_testing_config=model_testing_config ) + if torch.cuda.device_count() < distributed_save_load_config.num_gpus: + pytest.skip( + f"Not enough GPUs to run dependency: {torch.cuda.device_count()} < {distributed_save_load_config.num_gpus}" + ) report_subtest(distributed_save_load_config.save_path, distributed_save_load_config.num_gpus) load_and_compare_checkpoints( DistributedCheckpointFormat, 
diff --git a/tests/models/test_model.py b/tests/models/test_model.py index df3e52b7a..0c58afade 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -94,6 +94,8 @@ def test_model_distributed( config = DISTRIBUTED_TESTING_CONFIGS[config_name] if model_testing_config.should_skip(config): pytest.skip(f"Configuration not supported.") + if torch.cuda.device_count() < config.num_gpus: + pytest.skip(f"Not enough GPUs: {torch.cuda.device_count()} < {config.num_gpus}") report_subtest(run_test_script_base_path / config.name, config.num_gpus) if config.compare is not None: if not check_subtest_success(run_test_script_base_path / config.compare): From b1d4b8df88d595cc21db899f23e47c380e42ffc2 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 29 Jan 2026 16:41:51 -0500 Subject: [PATCH 03/22] fix --- fast_llm/functional/entropy_loss.py | 16 ++++++++-------- tests/layers/test_lm_losses.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 4bdf87c3f..d56c745ae 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -46,7 +46,7 @@ def torch_entropy_loss_forward_backward( if entropy_loss_type == EntropyLossType.cross_entropy: if target_format == TargetFormat.logits: target = torch.softmax(target, dim=-1) - loss = torch.nn.functional.cross_entropy(logits_scaled, target, reduction="none") + per_sample_loss = torch.nn.functional.cross_entropy(logits_scaled, target, reduction="none") else: predicted_log_probability = torch.nn.functional.log_softmax(logits_scaled, dim=-1) if target_format == TargetFormat.logits: @@ -61,14 +61,14 @@ def torch_entropy_loss_forward_backward( target_probability = target_probability * loss_mask.unsqueeze(-1) target_log_probability = target_probability.add(1.0e-10).log() if entropy_loss_type == EntropyLossType.forward_kl: - loss = torch.nn.functional.kl_div( + per_sample_loss = 
torch.nn.functional.kl_div( predicted_log_probability, target_log_probability, reduction="none", log_target=True, ) elif entropy_loss_type == EntropyLossType.reverse_kl: - loss = torch.nn.functional.kl_div( + per_sample_loss = torch.nn.functional.kl_div( target_log_probability, predicted_log_probability, reduction="none", @@ -76,11 +76,11 @@ def torch_entropy_loss_forward_backward( ) else: raise NotImplementedError(entropy_loss_type) - loss = loss.sum(dim=-1) + per_sample_loss = per_sample_loss.sum(dim=-1) if loss_mask is not None: - loss = loss * loss_mask - loss = loss.mean() + per_sample_loss = per_sample_loss * loss_mask + loss = per_sample_loss.mean() if grad_output is None: grad = None @@ -145,7 +145,7 @@ def _fused_reverse_kl_base( # Compute loss terms: student_probs * log_ratio, then sum over vocab # This is equivalent to kl_div(..., log_target=True) but more memory efficient log_ratio = predicted_log_probability - target_log_probability - per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1, keepdim=True) + per_sample_loss = (predicted_probability * log_ratio).sum(dim=-1) if group is not None: all_reduce(per_sample_loss, op=ReduceOp.SUM, group=group) @@ -154,7 +154,7 @@ def _fused_reverse_kl_base( else: # Gradient: d/d(logits) KL(q||p) = q * (log(q/p) - E_q[log(q/p)]) # where E_q[log(q/p)] is the expected log ratio under the student distribution - grad = (log_ratio - per_sample_loss) * predicted_probability * grad_output + grad = (log_ratio - per_sample_loss.unsqueeze(-1)) * predicted_probability * grad_output return per_sample_loss, grad diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 38c69b98e..639a3ba7c 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -115,8 +115,8 @@ def reference_dpo_loss( (500, 1.0, 4.0, False, DataType.float32), # Loss scaling (500, 4.0, 1.0, False, DataType.float32), # Grad scaling (500, 1.0, 1.0, True, DataType.float32), # Loss masking - (500, 1.0, 
1.0, True, DataType.float16), # Fp16 - (500, 1.0, 1.0, True, DataType.bfloat16), # Bf16 + (500, 1.0, 1.0, False, DataType.float16), # Fp16 + (500, 1.0, 1.0, True, DataType.bfloat16), # Bf16, loss masking (65538, 1.0, 1.0, False, DataType.float32), # Above max block size ) From c326bffceae8d118ccc7d1bb502de1ebe188e131 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 29 Jan 2026 22:49:24 -0500 Subject: [PATCH 04/22] fix --- tests/conftest.py | 2 +- tests/utils/dataset.py | 7 +++++-- tests/utils/model_configs.py | 2 +- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index f93eec215..4f7d7bad0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -256,7 +256,7 @@ def pytest_runtest_call(item: pytest.Function): if torch.cuda.is_available(): # Empty cache to check is cuda is still working (TODO: Is there a better way? Can we kill the worker?) try: - torch.cuda.empty_cache() + torch.cuda.synchronize() except RuntimeError: pytest.skip("Cuda runtime unavailable due to an error in an earlier test.") manager.handle_missing(item) diff --git a/tests/utils/dataset.py b/tests/utils/dataset.py index 4ad122947..d1b627ecc 100644 --- a/tests/utils/dataset.py +++ b/tests/utils/dataset.py @@ -10,6 +10,7 @@ from fast_llm.data.preprocessing.abstract import NullPreprocessingConfig from fast_llm.data.preprocessing.image_patch import ImagePatchConfig from fast_llm.data.preprocessing.language_model import LanguageModelPreprocessingConfig +from fast_llm.engine.distributed.config import DistributedConfig from fast_llm.utils import padded_cumsum from tests.utils.global_variables import DATASET_CACHE, MODEL_TEST_VOCAB_SIZE, TOKENIZER_FILE, TOKENIZER_PATH @@ -184,6 +185,8 @@ def _get_test_dataset( hf_path = path / "hf" if not config_only and not all(config_path.is_file() for config_path in config_paths): + # Not supported for parallel tests, but dataset should already exist anyway. 
+ assert DistributedConfig.default_world_size == 1 dataset = _get_hf_test_dataset( seed=seed, num_documents=num_documents, @@ -325,7 +328,7 @@ def get_model_test_dataset(config_only: bool = False): num_documents=200, max_loss_masking_spans=5, max_vocab_size=MODEL_TEST_VOCAB_SIZE, - splits={"training": 180, "validation": 19, "test": 1}, + splits={"training": 180, "validation": 18, "test": 2}, config_only=config_only, ) @@ -345,6 +348,6 @@ def get_multimodal_test_dataset(config_only: bool = False): image_break_token=None, image_end_token=None, ), - splits={"training": 180, "validation": 19, "test": 1}, + splits={"training": 180, "validation": 18, "test": 2}, config_only=config_only, ) diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 5e7526377..5a6aff831 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -986,7 +986,7 @@ def update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.not_implemented, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, - compare_factor=10.0, # similar to gdn with compare_factor 2 fails fp16 and bf16 tests in the normalizaiton layer when using rms_norm_gated from fla + compare_factor=15.0, # similar to gdn with compare_factor 2 fails fp16 and bf16 tests in the normalization layer when using rms_norm_gated from fla # note: tp is excluded because there is currently no gradient reductions implemented for tp norm in gdn.py (STP works though). # we should be using STP with this model, not TP! 
skip_tests=("sdp", "ms", TP_NO_STP), From 02c28a50103a38ced74593fc0e0f48121c63acde Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 30 Jan 2026 22:30:48 -0500 Subject: [PATCH 05/22] Triton loss --- fast_llm/functional/triton/cross_entropy.py | 2 +- tests/layers/test_lm_losses.py | 24 ++++++++++----------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index ef2039ade..a8becfb68 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -128,6 +128,7 @@ def triton_cross_entropy_forward_backward( target_format: TargetFormat, entropy_loss_type: EntropyLossType, temperature: float = 1.0, + looped: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -143,7 +144,6 @@ def triton_cross_entropy_forward_backward( n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) block_size = triton.next_power_of_2(n_cols) - assert block_size <= TritonConfig.MAX_BLOCK_SIZE_BYTES num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 639a3ba7c..1a31db90a 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -1,4 +1,3 @@ -import contextlib import pathlib import random @@ -173,18 +172,17 @@ def _test_entropy_loss( # Triton implementation only supports cross-entropy. 
return assert TritonConfig.TRITON_ENABLED - with pytest.raises(AssertionError) if num_columns > 65536 else contextlib.nullcontext(): - out_triton, grad_triton = entropy_loss_forward_backward( - logits=logits, - target=target, - loss_mask=loss_mask, - grad_output=grad_output, - logits_scale_factor=logits_scale_factor, - target_format=target_format, - entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.triton, - ) - _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) + out_triton, grad_triton = entropy_loss_forward_backward( + logits=logits, + target=target, + loss_mask=loss_mask, + grad_output=grad_output, + logits_scale_factor=logits_scale_factor, + target_format=target_format, + entropy_loss_type=entropy_loss_type, + implementation=EntropyLossImplementation.triton, + ) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): From 22c8e0ba0935a63db086042a18f824903c8f4f08 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 30 Jan 2026 22:33:49 -0500 Subject: [PATCH 06/22] Triton loss --- fast_llm/functional/triton/cross_entropy.py | 1 - 1 file changed, 1 deletion(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index a8becfb68..354223a96 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -128,7 +128,6 @@ def triton_cross_entropy_forward_backward( target_format: TargetFormat, entropy_loss_type: EntropyLossType, temperature: float = 1.0, - looped: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, From 7491e0f8bf0619a5ec647d4af858589a191f4761 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 30 Jan 2026 
23:39:32 -0500 Subject: [PATCH 07/22] Parallel attempt --- fast_llm/functional/entropy_loss.py | 3 +- fast_llm/functional/triton/cross_entropy.py | 62 ++++++++++++++++++- .../language_model/loss/entropy_loss.py | 35 ++++------- 3 files changed, 73 insertions(+), 27 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index d56c745ae..37486ddc0 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -14,6 +14,7 @@ def torch_entropy_loss_forward_backward( logits_scale_factor: float, target_format: TargetFormat, entropy_loss_type: EntropyLossType, + group: ProcessGroup | None = None, temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: # (), (*batch, vocab) """ @@ -21,7 +22,7 @@ def torch_entropy_loss_forward_backward( The cross-entropy kernels themselves are well-optimized, but the need for explicit casting and separate forward and backward kernels lead to poor performance. """ - + assert group is None # Torch methods require flattened batch dimension. target = target.flatten() if target_format == TargetFormat.labels else target.flatten(0, -2) if target_format == TargetFormat.labels: diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 354223a96..12d96a881 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -5,12 +5,42 @@ from fast_llm.utils import Assert +@triton_jit() +def triton_softmax_base_kernel( + logits_ptr, + max_logits_ptr, + sum_exp_logits_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + logits_scale_factor: tl_constexpr, + block_size: tl_constexpr, +): + # TODO: Int64 ptr only if needed? 
+ block_idx = tl.program_id(0).to(tl.int64) + col_offsets = tl.arange(0, block_size) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + mask = col_offsets < n_cols + + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + + tl.store(max_logits_ptr + block_idx, max_logits) + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) + + @triton_jit() def triton_cross_entropy_forward_backward_kernel( logits_ptr, labels_ptr, grad_logits_ptr, losses_ptr, + max_logits_ptr, + sum_exp_logits_ptr, grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, @@ -28,9 +58,15 @@ def triton_cross_entropy_forward_backward_kernel( if logits_scale_factor != 1.0: logits *= logits_scale_factor - max_logits = tl.max(logits, 0) + if max_logits_ptr is None: + max_logits = tl.max(logits, 0) + else: + max_logits = tl.load(max_logits_ptr + block_idx) exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + if sum_exp_logits_ptr is None: + sum_exp_logits = tl.sum(exp_logits, 0) + else: + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) label_idx = tl.load(labels_ptr + block_idx) @@ -127,6 +163,7 @@ def triton_cross_entropy_forward_backward( logits_scale_factor: float, target_format: TargetFormat, entropy_loss_type: EntropyLossType, + group: torch.distributed.ProcessGroup | None = None, temperature: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor]: """ @@ -148,11 +185,31 @@ def triton_cross_entropy_forward_backward( # TODO: Safe to do inplace? 
grad_logits = None if grad_output is None else torch.empty_like(logits) if target_format == TargetFormat.labels: + if group is None: + max_logits = sum_exp_logits = None + else: + local_max_logits = torch.empty_like(losses) + sum_exp_logits = torch.empty_like(losses) + triton_softmax_base_kernel[(n_rows,)]( + logits, + local_max_logits, + sum_exp_logits, + n_cols, + logits.stride(-2), + logits_scale_factor, + block_size=block_size, + ) + max_logits = local_max_logits.clone() + torch.distributedall_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) + sum_exp_logits = sum_exp_logits * (local_max_logits - max_logits).exp() + torch.distributedall_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) triton_cross_entropy_forward_backward_kernel[(n_rows,)]( logits, target, grad_logits, losses, + max_logits, + sum_exp_logits, None if grad_output is None else grad_output / n_rows, n_cols, logits.stride(-2), @@ -162,6 +219,7 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, ) else: + assert group is None if loss_mask is not None: assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 1dfd3920c..550f8f330 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -122,27 +122,14 @@ def entropy_loss_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - if group: - Assert.eq(implementation, EntropyLossImplementation.fused) - return fused_entropy_loss_forward_backward( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - group, - temperature, - ) - else: - return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( - logits, - target, - loss_mask, - 
grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - temperature=temperature, - ) + return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + logits, + target, + loss_mask, + grad_output, + logits_scale_factor, + target_format, + entropy_loss_type, + group, + temperature=temperature, + ) From b8e7179976f49c64947c177d93293c3786e013b0 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Tue, 3 Feb 2026 00:26:45 -0500 Subject: [PATCH 08/22] fix --- fast_llm/functional/triton/cross_entropy.py | 145 ++++++++++++++------ 1 file changed, 105 insertions(+), 40 deletions(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 12d96a881..22498cf48 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -6,10 +6,13 @@ @triton_jit() -def triton_softmax_base_kernel( +def triton_cross_entropy_forward_parallel_kernel( logits_ptr, + labels_ptr, max_logits_ptr, sum_exp_logits_ptr, + predicted_logits_ptr, + col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, logits_scale_factor: tl_constexpr, @@ -29,6 +32,17 @@ def triton_softmax_base_kernel( exp_logits = tl.exp(logits - max_logits) sum_exp_logits = tl.sum(exp_logits, 0) + if labels_ptr is not None and predicted_logits_ptr is not None: + label_idx = tl.load(labels_ptr + block_idx) - col_min + if label_idx < 0 or label_idx >= n_cols: + # Loss mask + predicted_logits = 0.0 + else: + predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + predicted_logits *= logits_scale_factor + tl.store(predicted_logits_ptr + block_idx, predicted_logits) + tl.store(max_logits_ptr + block_idx, max_logits) tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) @@ -42,6 +56,7 @@ def triton_cross_entropy_forward_backward_kernel( max_logits_ptr, sum_exp_logits_ptr, grad_losses, + col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, 
grad_logits_stride_0: tl_constexpr, @@ -68,26 +83,32 @@ def triton_cross_entropy_forward_backward_kernel( else: sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) - label_idx = tl.load(labels_ptr + block_idx) + label_idx = tl.load(labels_ptr + block_idx) - col_min - if label_idx < 0: - # Loss mask - loss = 0.0 - else: - label_logits = tl.load(logits_ptr + label_idx).to(tl.float32) - if logits_scale_factor != 1.0: - label_logits *= logits_scale_factor - loss = tl.log(sum_exp_logits) + max_logits - label_logits - tl.store(losses_ptr + block_idx, loss) + if losses_ptr is not None: + if label_idx < 0 or label_idx >= n_cols: + # Loss mask + loss = 0.0 + else: + predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) + if logits_scale_factor != 1.0: + predicted_logits *= logits_scale_factor + loss = tl.log(sum_exp_logits) + max_logits - predicted_logits + tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - if label_idx < 0: + if label_idx < -col_min: grad_losses = 0.0 + elif logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor grad_base = exp_logits / sum_exp_logits - grad_logits = grad_losses * tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) - if logits_scale_factor != 1.0: - grad_logits *= logits_scale_factor - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + if label_idx < 0 or label_idx >= n_cols: + grad_logits = grad_base + else: + grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask + ) @triton_jit() @@ -155,6 +176,25 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) +@torch.compile +def _rescale_sum_exp_logits( + sum_exp_logits: torch.Tensor, + local_max_logits: torch.Tensor, + max_logits: 
torch.Tensor, +) -> torch.Tensor: + return sum_exp_logits * (local_max_logits - max_logits).exp() + + +@torch.compile +def _calculate_loss( + predicted_logits: torch.Tensor, + target: torch.Tensor, + sum_exp_logits: torch.Tensor, + max_logits: torch.Tensor, +) -> torch.Tensor: + return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0).mean() + + def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -181,45 +221,69 @@ def triton_cross_entropy_forward_backward( n_cols = logits.size(-1) block_size = triton.next_power_of_2(n_cols) num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) # TODO: Safe to do inplace? grad_logits = None if grad_output is None else torch.empty_like(logits) if target_format == TargetFormat.labels: if group is None: - max_logits = sum_exp_logits = None + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + losses, + None, + None, + None if grad_output is None else grad_output / n_rows, + 0, + n_cols, + logits.stride(-2), + None if grad_output is None else grad_logits.stride(-2), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + ) + loss = losses.mean() else: - local_max_logits = torch.empty_like(losses) - sum_exp_logits = torch.empty_like(losses) - triton_softmax_base_kernel[(n_rows,)]( + predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(predicted_logits) + sum_exp_logits = torch.empty_like(predicted_logits) + triton_cross_entropy_forward_parallel_kernel[(n_rows,)]( logits, + target, local_max_logits, sum_exp_logits, + predicted_logits, + n_cols * group.rank(), n_cols, logits.stride(-2), logits_scale_factor, block_size=block_size, ) max_logits = local_max_logits.clone() - 
torch.distributedall_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) - sum_exp_logits = sum_exp_logits * (local_max_logits - max_logits).exp() - torch.distributedall_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( - logits, - target, - grad_logits, - losses, - max_logits, - sum_exp_logits, - None if grad_output is None else grad_output / n_rows, - n_cols, - logits.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, - ) + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) + sum_exp_logits = _rescale_sum_exp_logits(sum_exp_logits, local_max_logits, max_logits) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) + loss = _calculate_loss(predicted_logits, target, sum_exp_logits, max_logits) + triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + logits, + target, + grad_logits, + None, + max_logits, + sum_exp_logits, + None if grad_output is None else grad_output / n_rows, + n_cols * group.rank(), + n_cols, + logits.stride(-2), + None if grad_output is None else grad_logits.stride(-2), + logits_scale_factor, + block_size=block_size, + num_warps=num_warps, + ) else: assert group is None + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) if loss_mask is not None: assert loss_mask.is_contiguous() triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( @@ -238,4 +302,5 @@ def triton_cross_entropy_forward_backward( num_warps=num_warps, from_logits=target_format == TargetFormat.logits, ) - return losses.mean(), grad_logits + loss = losses.mean() + return loss, grad_logits From 3c3e0c8704d14fdfe77d41370e70a6b0b1242bce Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier 
Date: Tue, 3 Feb 2026 15:45:19 -0500 Subject: [PATCH 09/22] fixes --- fast_llm/functional/triton/cross_entropy.py | 97 ++++++++++++------- .../language_model/loss/entropy_loss.py | 2 + 2 files changed, 64 insertions(+), 35 deletions(-) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 22498cf48..82312b99e 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -5,6 +5,32 @@ from fast_llm.utils import Assert +@triton_jit() +def triton_fused_softmax_base( + logits_ptr, + n_cols: tl_constexpr, + logits_scale_factor: tl_constexpr, + block_size: tl_constexpr, +): + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl.arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + if col_offset == 0: + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + else: + new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) + max_logits = new_max_logits + return exp_logits, sum_exp_logits, max_logits, mask + + @triton_jit() def triton_cross_entropy_forward_parallel_kernel( logits_ptr, @@ -20,17 +46,11 @@ def triton_cross_entropy_forward_parallel_kernel( ): # TODO: Int64 ptr only if needed? 
block_idx = tl.program_id(0).to(tl.int64) - col_offsets = tl.arange(0, block_size) logits_ptr = logits_ptr + block_idx * logits_stride_0 - mask = col_offsets < n_cols - - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( + logits_ptr, n_cols, logits_scale_factor, block_size + ) if labels_ptr is not None and predicted_logits_ptr is not None: label_idx = tl.load(labels_ptr + block_idx) - col_min @@ -65,22 +85,14 @@ def triton_cross_entropy_forward_backward_kernel( ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) - col_offsets = tl.arange(0, block_size) logits_ptr = logits_ptr + block_idx * logits_stride_0 - mask = col_offsets < n_cols - - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor - if max_logits_ptr is None: - max_logits = tl.max(logits, 0) + if max_logits_ptr is None or sum_exp_logits_ptr is None: + exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( + logits_ptr, n_cols, logits_scale_factor, block_size + ) else: max_logits = tl.load(max_logits_ptr + block_idx) - exp_logits = tl.exp(logits - max_logits) - if sum_exp_logits_ptr is None: - sum_exp_logits = tl.sum(exp_logits, 0) - else: sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) label_idx = tl.load(labels_ptr + block_idx) - col_min @@ -89,6 +101,7 @@ def triton_cross_entropy_forward_backward_kernel( if label_idx < 0 or label_idx >= n_cols: # Loss mask loss = 0.0 + predicted_logits = 0.0 else: predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) if logits_scale_factor != 1.0: @@ -97,18 +110,28 @@ def 
triton_cross_entropy_forward_backward_kernel( tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - if label_idx < -col_min: - grad_losses = 0.0 - elif logits_scale_factor != 1.0: - grad_losses *= logits_scale_factor - grad_base = exp_logits / sum_exp_logits - if label_idx < 0 or label_idx >= n_cols: - grad_logits = grad_base - else: - grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) - tl.store( - grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask - ) + # Run in reverse order to maximize input and cache reuse. + for col_offset in tl.static_range((n_cols - 1) // block_size * block_size, -1, -block_size): + if max_logits_ptr is None or sum_exp_logits_ptr is None or col_offset != n_cols - block_size: + col_offsets = tl.arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + + if label_idx < -col_min: + grad_losses = 0.0 + elif logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + grad_base = exp_logits / sum_exp_logits + if label_idx < 0 or label_idx >= n_cols: + grad_logits = grad_base + else: + grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask + ) @triton_jit() @@ -205,6 +228,8 @@ def triton_cross_entropy_forward_backward( entropy_loss_type: EntropyLossType, group: torch.distributed.ProcessGroup | None = None, temperature: float = 1.0, + block_size: int | None = None, + num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ A fast triton implementation of cross-entropy, which combines the casting and forward and backward passes, @@ -219,8 +244,10 @@ def 
triton_cross_entropy_forward_backward( assert target.is_contiguous() n_rows = logits.shape[:-1].numel() n_cols = logits.size(-1) - block_size = triton.next_power_of_2(n_cols) - num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) # TODO: Safe to do inplace? grad_logits = None if grad_output is None else torch.empty_like(logits) if target_format == TargetFormat.labels: diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 550f8f330..351aa210b 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -106,6 +106,7 @@ def entropy_loss_forward_backward( temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, + **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ Select the appropriate implementation of cross-entropy. 
@@ -132,4 +133,5 @@ def entropy_loss_forward_backward( entropy_loss_type, group, temperature=temperature, + **kwargs, ) From 2d293eaf7dc5782c2c3c204d3fe7f8c1747bd138 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 5 Feb 2026 02:51:31 -0500 Subject: [PATCH 10/22] Cross-entropy from distribution --- fast_llm/functional/entropy_loss.py | 8 +- fast_llm/functional/triton/__init__.py | 17 + fast_llm/functional/triton/cross_entropy.py | 509 ++++++++++++++------ tests/layers/test_lm_losses.py | 89 ++-- 4 files changed, 444 insertions(+), 179 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 37486ddc0..25e1ae317 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -121,7 +121,7 @@ def fused_softmax_base( @torch.compile -def _fused_reverse_kl_base( +def _fused_reverse_kl_base_from_distribution( logits: torch.Tensor, # (*batch, vocab) target: torch.Tensor, # (*batch, vocab) grad_output: float | None, @@ -161,7 +161,7 @@ def _fused_reverse_kl_base( @torch.compile -def _fused_cross_entropy_base( +def _fused_cross_entropy_base_from_distribution( logits: torch.Tensor, # (*batch, vocab) target: torch.Tensor, # (*batch, vocab) grad_output: float | None, @@ -302,7 +302,7 @@ def fused_entropy_loss_forward_backward( group, ) elif entropy_loss_type in (EntropyLossType.cross_entropy, EntropyLossType.forward_kl): - per_sample_loss, grad = _fused_cross_entropy_base( + per_sample_loss, grad = _fused_cross_entropy_base_from_distribution( logits, target, grad_output, @@ -313,7 +313,7 @@ def fused_entropy_loss_forward_backward( return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, ) elif entropy_loss_type == EntropyLossType.reverse_kl: - per_sample_loss, grad = _fused_reverse_kl_base( + per_sample_loss, grad = _fused_reverse_kl_base_from_distribution( logits, target, grad_output, diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py 
index 778559db1..82f67621e 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -1,16 +1,33 @@ +import torch + from fast_llm.utils import InvalidObject, try_decorate try: import triton + import triton.knobs import triton.language as tl tl_constexpr = tl.constexpr TritonConfig = triton.Config + triton_available = torch.cuda.is_available() or triton.knobs.runtime.interpret except ImportError as e: triton = InvalidObject(e) tl = triton tl_constexpr = None TritonConfig = lambda *args, **kwargs: None + triton_available = False triton_jit = try_decorate(lambda: triton.jit) triton_autotune = try_decorate(lambda: triton.autotune) + + +if not triton_available: + tl_arange = None +elif triton.knobs.runtime.interpret: + # Workaround for a triton bug. + @triton_jit + def tl_arange(start, end): + return tl.arange(int(start), int(end)) + +else: + tl_arange = tl.arange diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 82312b99e..6fb8e930d 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,7 +1,7 @@ import torch from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.utils import Assert @@ -9,11 +9,11 @@ def triton_fused_softmax_base( logits_ptr, n_cols: tl_constexpr, - logits_scale_factor: tl_constexpr, block_size: tl_constexpr, + logits_scale_factor: tl_constexpr = 1.0, ): for col_offset in tl.static_range(0, n_cols, block_size): - col_offsets = tl.arange(col_offset, col_offset + block_size) + col_offsets = tl_arange(col_offset, col_offset + block_size) mask = col_offsets < n_cols logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) if logits_scale_factor != 1.0: @@ -28,28 
+28,28 @@ def triton_fused_softmax_base( exp_logits = tl.exp(logits - new_max_logits) sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) max_logits = new_max_logits - return exp_logits, sum_exp_logits, max_logits, mask + return exp_logits, sum_exp_logits, max_logits @triton_jit() -def triton_cross_entropy_forward_parallel_kernel( +def triton_cross_entropy_forward_from_labels_parallel_kernel( logits_ptr, labels_ptr, - max_logits_ptr, - sum_exp_logits_ptr, - predicted_logits_ptr, - col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, - logits_scale_factor: tl_constexpr, block_size: tl_constexpr, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + predicted_logits_ptr=None, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 - exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( - logits_ptr, n_cols, logits_scale_factor, block_size + exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) if labels_ptr is not None and predicted_logits_ptr is not None: @@ -63,33 +63,35 @@ def triton_cross_entropy_forward_parallel_kernel( predicted_logits *= logits_scale_factor tl.store(predicted_logits_ptr + block_idx, predicted_logits) - tl.store(max_logits_ptr + block_idx, max_logits) - tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) + if max_logits_ptr is not None: + tl.store(max_logits_ptr + block_idx, max_logits) + if sum_exp_logits_ptr is not None: + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) @triton_jit() -def triton_cross_entropy_forward_backward_kernel( +def triton_cross_entropy_forward_backward_from_labels_kernel( logits_ptr, labels_ptr, - grad_logits_ptr, - losses_ptr, - max_logits_ptr, - sum_exp_logits_ptr, - 
grad_losses, - col_min: tl_constexpr, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, - grad_logits_stride_0: tl_constexpr, - logits_scale_factor: tl_constexpr, block_size: tl_constexpr, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + col_min: tl_constexpr = 0, + logits_scale_factor: tl_constexpr = 1.0, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 if max_logits_ptr is None or sum_exp_logits_ptr is None: - exp_logits, sum_exp_logits, max_logits, mask = triton_fused_softmax_base( - logits_ptr, n_cols, logits_scale_factor, block_size + exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) else: max_logits = tl.load(max_logits_ptr + block_idx) @@ -101,7 +103,6 @@ def triton_cross_entropy_forward_backward_kernel( if label_idx < 0 or label_idx >= n_cols: # Loss mask loss = 0.0 - predicted_logits = 0.0 else: predicted_logits = tl.load(logits_ptr + label_idx).to(tl.float32) if logits_scale_factor != 1.0: @@ -110,20 +111,21 @@ def triton_cross_entropy_forward_backward_kernel( tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: + if label_idx < -col_min: + grad_losses = 0.0 + elif logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor # Run in reverse order to maximize input and cache reuse. 
- for col_offset in tl.static_range((n_cols - 1) // block_size * block_size, -1, -block_size): - if max_logits_ptr is None or sum_exp_logits_ptr is None or col_offset != n_cols - block_size: - col_offsets = tl.arange(col_offset, col_offset + block_size) - mask = col_offsets < n_cols + col_offset_start = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) if logits_scale_factor != 1.0: logits *= logits_scale_factor exp_logits = tl.exp(logits - max_logits) - if label_idx < -col_min: - grad_losses = 0.0 - elif logits_scale_factor != 1.0: - grad_losses *= logits_scale_factor grad_base = exp_logits / sum_exp_logits if label_idx < 0 or label_idx >= n_cols: grad_logits = grad_base @@ -135,68 +137,209 @@ def triton_cross_entropy_forward_backward_kernel( @triton_jit() -def triton_cross_entropy_from_distribution_forward_backward_kernel( +def triton_predicted_logits_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + block_size: tl_constexpr, + from_logits: tl_constexpr = True, + target_logits_scale_factor: tl_constexpr = 1.0, + logits_scale_factor: tl_constexpr = 1.0, + unscaled_probabilities: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. 
+): + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + if from_logits: + target_logits = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if target_logits_scale_factor != 1.0: + target_logits *= target_logits_scale_factor + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + + if col_offset == 0: + max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + if from_logits: + target_max_logits = tl.max(target_logits, 0) + target_exp_logits = tl.exp(target_logits - target_max_logits) + target_sum_exp_logits = tl.sum(target_exp_logits, 0) + predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits, 0)) + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + predicted_logits = tl.sum(tl.where(mask, target * logits, 0)) + target_max_logits = None + target_sum_exp_logits = None + else: + new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) + max_logits = new_max_logits + if from_logits: + target_new_max_logits = tl.maximum(tl.max(target_logits, 0), target_max_logits) + target_exp_logits = tl.exp(target_logits - target_new_max_logits) + target_sum_exp_logits = tl.sum(target_exp_logits, 0) + target_sum_exp_logits * tl.exp( + target_max_logits - target_new_max_logits + ) + predicted_logits = predicted_logits * tl.exp(target_max_logits - target_new_max_logits) + tl.sum( + tl.where(mask, target_exp_logits * logits, 0) + ) + target_max_logits = target_new_max_logits + else: + 
predicted_logits += tl.sum(tl.where(mask, target * logits, 0)) + + if from_logits: + target = target_exp_logits + if not unscaled_probabilities: + predicted_logits /= target_sum_exp_logits + target /= target_sum_exp_logits + + return predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target + + +@triton_jit() +def triton_cross_entropy_from_distribution_forward_parallel_kernel( logits_ptr, target_ptr, - loss_mask_ptr, - grad_logits_ptr, - losses_ptr, - grad_losses, n_cols: tl_constexpr, logits_stride_0: tl_constexpr, target_stride_0: tl_constexpr, - grad_logits_stride_0: tl_constexpr, - logits_scale_factor: tl_constexpr, - from_logits: tl_constexpr, block_size: tl_constexpr, + loss_mask_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + predicted_logits_ptr=None, + from_logits: tl_constexpr = True, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) - col_offsets = tl.arange(0, block_size) - mask = col_offsets < n_cols + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 - if loss_mask_ptr is not None: - loss_mask = tl.load(loss_mask_ptr + block_idx) - if loss_mask == 0: - tl.store(losses_ptr + block_idx, 0) - if grad_losses is not None: - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=mask) - return + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. 
+ tl.store(predicted_logits_ptr + block_idx, 0) + return - logits = tl.load(logits_ptr + block_idx * logits_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( - tl.float32 + predicted_logits, _, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target = ( + triton_predicted_logits_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + unscaled_probabilities=True, + ) ) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor + if predicted_logits_ptr is not None: + tl.store(predicted_logits_ptr + block_idx, predicted_logits) + if max_logits_ptr is not None: + tl.store(max_logits_ptr + block_idx, max_logits) + if sum_exp_logits_ptr is not None: + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) - max_logits = tl.max(logits, 0) - logits_norm = logits - max_logits - exp_logits = tl.exp(logits_norm) - sum_exp_logits = tl.sum(exp_logits, 0) + if target_max_logits_ptr is not None: + tl.store(target_max_logits_ptr + block_idx, target_max_logits) + if target_sum_exp_logits_ptr is not None: + tl.store(target_sum_exp_logits_ptr + block_idx, target_sum_exp_logits) - target = tl.load(target_ptr + block_idx * target_stride_0 + col_offsets, mask=mask, other=-float("inf")).to( - tl.float32 - ) - if from_logits: - if logits_scale_factor != 1.0: - target *= logits_scale_factor - max_target_logits = tl.max(target, 0) - exp_target_logits = tl.exp(target - max_target_logits) - sum_exp_target_logits = tl.sum(exp_target_logits, 0) - target = exp_target_logits / sum_exp_target_logits - # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) - loss = tl.log(sum_exp_logits) - tl.sum(tl.where(mask, target * logits_norm, 0), 0) - tl.store(losses_ptr + block_idx, loss) +@triton_jit() +def triton_cross_entropy_from_distribution_forward_backward_kernel( + logits_ptr, + 
target_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, + block_size: tl_constexpr, + loss_mask_ptr=None, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + from_logits: tl_constexpr = True, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target = ( + triton_predicted_logits_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + ) + ) + else: + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + if grad_losses is not None and from_logits: + target_max_logits = tl.load(target_max_logits_ptr + block_idx) + target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) + + if losses_ptr is not None: + # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) + loss 
= tl.log(sum_exp_logits) + max_logits - predicted_logits + tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. - grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) if logits_scale_factor != 1.0: - grad_logits *= logits_scale_factor - if loss_mask_ptr is not None: - grad_logits = grad_logits - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + grad_losses *= logits_scale_factor + # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. + col_offset_start = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + if from_logits: + target_logits = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if target_logits_scale_factor != 1.0: + target_logits *= target_logits_scale_factor + target = tl.exp(target_logits - target_max_logits) / target_sum_exp_logits + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + + grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) @torch.compile @@ -208,8 +351,20 @@ def _rescale_sum_exp_logits( return sum_exp_logits * (local_max_logits - max_logits).exp() +def _parallel_sum_exp_logits( + sum_exp_logits: torch.Tensor, + local_max_logits: torch.Tensor, + group: torch.distributed.ProcessGroup | None, +) -> torch.Tensor: + max_logits 
= local_max_logits.clone() + torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) + sum_exp_logits = _rescale_sum_exp_logits(sum_exp_logits, local_max_logits, max_logits) + torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) + return max_logits, sum_exp_logits + + @torch.compile -def _calculate_loss( +def _cross_entropy_loss_from_labels( predicted_logits: torch.Tensor, target: torch.Tensor, sum_exp_logits: torch.Tensor, @@ -218,6 +373,30 @@ def _calculate_loss( return torch.where(target.flatten() >= 0, sum_exp_logits.log() + max_logits - predicted_logits, 0).mean() +@torch.compile +def _rescale_predicted_logits( + predicted_logits: torch.Tensor, + target_sum_exp_logits: torch.Tensor, + local_target_max_logits: torch.Tensor, + target_max_logits: torch.Tensor, +): + # We skipped the division by `target_sum_exp_logits` in the triton kernel so we do it here. + return predicted_logits * torch.exp(local_target_max_logits - target_max_logits) / target_sum_exp_logits + + +@torch.compile +def _cross_entropy_loss_from_distribution( + predicted_logits: torch.Tensor, + loss_mask: torch.Tensor | None, + sum_exp_logits: torch.Tensor, + max_logits: torch.Tensor, +) -> torch.Tensor: + per_sample_losses = sum_exp_logits.log() + max_logits - predicted_logits + if loss_mask is not None: + per_sample_losses = torch.where(loss_mask.flatten(), per_sample_losses, 0) + return per_sample_losses.mean() + + def triton_cross_entropy_forward_backward( logits: torch.Tensor, target: torch.Tensor, @@ -248,86 +427,134 @@ def triton_cross_entropy_forward_backward( block_size = min(triton.next_power_of_2(n_cols), 32768) if num_warps is None: num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + kwargs = { + "logits_stride_0": logits.stride(-2), + "n_cols": n_cols, + "logits_scale_factor": logits_scale_factor, + "block_size": block_size, + "num_warps": num_warps, + } + # TODO: Safe to do inplace? 
grad_logits = None if grad_output is None else torch.empty_like(logits) + backward_kwargs = ( + {} + if grad_output is None + else { + "grad_logits_ptr": grad_logits, + "grad_losses": grad_output / n_rows, + "grad_logits_stride_0": grad_logits.stride(-2), + } + ) if target_format == TargetFormat.labels: if group is None: losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( logits, target, - grad_logits, - losses, - None, - None, - None if grad_output is None else grad_output / n_rows, - 0, - n_cols, - logits.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, + losses_ptr=losses, + **kwargs, + **backward_kwargs, ) loss = losses.mean() else: predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) local_max_logits = torch.empty_like(predicted_logits) sum_exp_logits = torch.empty_like(predicted_logits) - triton_cross_entropy_forward_parallel_kernel[(n_rows,)]( + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( logits, target, - local_max_logits, - sum_exp_logits, - predicted_logits, - n_cols * group.rank(), - n_cols, - logits.stride(-2), - logits_scale_factor, - block_size=block_size, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + predicted_logits_ptr=predicted_logits, + col_min=n_cols * group.rank(), + **kwargs, ) - max_logits = local_max_logits.clone() - torch.distributed.all_reduce(max_logits, op=torch.distributed.ReduceOp.MAX, group=group) - sum_exp_logits = _rescale_sum_exp_logits(sum_exp_logits, local_max_logits, max_logits) - torch.distributed.all_reduce(sum_exp_logits, op=torch.distributed.ReduceOp.SUM, group=group) + max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) torch.distributed.all_reduce(predicted_logits, 
op=torch.distributed.ReduceOp.SUM, group=group) - loss = _calculate_loss(predicted_logits, target, sum_exp_logits, max_logits) - triton_cross_entropy_forward_backward_kernel[(n_rows,)]( + loss = _cross_entropy_loss_from_labels(predicted_logits, target, sum_exp_logits, max_logits) + if grad_output is not None: + triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( + logits, + target, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + col_min=n_cols * group.rank(), + **kwargs, + **backward_kwargs, + ) + else: + if group is None: + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if loss_mask is not None: + assert loss_mask.is_contiguous() + triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, - grad_logits, - None, - max_logits, - sum_exp_logits, - None if grad_output is None else grad_output / n_rows, - n_cols * group.rank(), - n_cols, - logits.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, + loss_mask_ptr=loss_mask, + losses_ptr=losses, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + target_stride_0=target.stride(-2), + target_logits_scale_factor=logits_scale_factor / temperature, + from_logits=target_format == TargetFormat.logits, + **kwargs, + **backward_kwargs, + ) + loss = losses.mean() + else: + predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(predicted_logits) + sum_exp_logits = torch.empty_like(predicted_logits) + if target_format == TargetFormat.logits: + local_target_max_logits = torch.empty_like(predicted_logits) + target_sum_exp_logits = torch.empty_like(predicted_logits) + else: + local_target_max_logits = target_sum_exp_logits = None + + triton_cross_entropy_from_distribution_forward_parallel_kernel[(n_rows,)]( + logits, + target, + 
loss_mask_ptr=loss_mask, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + target_max_logits_ptr=local_target_max_logits, + target_sum_exp_logits_ptr=target_sum_exp_logits, + predicted_logits_ptr=predicted_logits, + target_stride_0=target.stride(-2), + target_logits_scale_factor=logits_scale_factor / temperature, + from_logits=target_format == TargetFormat.logits, + **kwargs, + **backward_kwargs, + ) + max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + if target_format == TargetFormat.logits: + target_max_logits, target_sum_exp_logits = _parallel_sum_exp_logits( + target_sum_exp_logits, local_target_max_logits, group + ) + predicted_logits = _rescale_predicted_logits( + predicted_logits, target_sum_exp_logits, local_target_max_logits, target_max_logits + ) + else: + target_max_logits = None + torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) + + loss = _cross_entropy_loss_from_distribution(predicted_logits, loss_mask, sum_exp_logits, max_logits) + triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + logits, + target, + loss_mask_ptr=loss_mask, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + target_max_logits_ptr=target_max_logits, + target_sum_exp_logits_ptr=target_sum_exp_logits, + predicted_logits_ptr=predicted_logits, + target_stride_0=target.stride(-2), + target_logits_scale_factor=logits_scale_factor / temperature, + from_logits=target_format == TargetFormat.logits, + **kwargs, + **backward_kwargs, ) - else: - assert group is None - losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if loss_mask is not None: - assert loss_mask.is_contiguous() - triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( - logits, - target / temperature, - loss_mask, - grad_logits, - losses, - None if grad_output is None else grad_output / n_rows, - n_cols, - logits.stride(-2), - 
target.stride(-2), - None if grad_output is None else grad_logits.stride(-2), - logits_scale_factor, - block_size=block_size, - num_warps=num_warps, - from_logits=target_format == TargetFormat.logits, - ) - loss = losses.mean() return loss, grad_logits diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 1a31db90a..37cd99fb0 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -10,6 +10,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedBackend from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.triton import triton_available from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward from fast_llm.layers.language_model.loss.loss import loss_forward_backward @@ -18,9 +19,6 @@ from tests.utils.dataset import get_random_spans from tests.utils.subtest import DistributedTestContext -VOCAB_SIZE = 100 -NUM_TOKENS = 200 - def _get_lm_loss_inputs( num_columns: int, loss_masking: bool, target_format: TargetFormat, batch_shape: tuple[int], dtype @@ -108,15 +106,15 @@ def reference_dpo_loss( _BATCH_SHAPES = ((64,), (16, 8)) _LOSS_PARAMETERS = ( - (500, 1.0, 1.0, False, DataType.float32), # Simple - (512, 1.0, 1.0, False, DataType.float32), # Power of 2 - (500, None, 1.0, False, DataType.float32), # No grad - (500, 1.0, 4.0, False, DataType.float32), # Loss scaling - (500, 4.0, 1.0, False, DataType.float32), # Grad scaling - (500, 1.0, 1.0, True, DataType.float32), # Loss masking - (500, 1.0, 1.0, False, DataType.float16), # Fp16 - (500, 1.0, 1.0, True, DataType.bfloat16), # Bf16, loss masking - (65538, 1.0, 1.0, False, DataType.float32), # Above max block size + (500, 1.0, 1.0, False, DataType.float32, None), # Simple + (256, 1.0, 1.0, False, DataType.float32, None), # 
Power of 2 + (500, None, 1.0, False, DataType.float32, None), # No grad + (500, 1.0, 4.0, False, DataType.float32, None), # Loss scaling + (500, 4.0, 1.0, False, DataType.float32, None), # Grad scaling + (500, 1.0, 1.0, True, DataType.float32, None), # Loss masking + (500, 1.0, 1.0, False, DataType.float16, None), # Fp16 + (500, 1.0, 1.0, False, DataType.float32, 256), # Looped + (1000, 2.0, 3.0, True, DataType.float16, 256), # Hard ) @@ -129,12 +127,15 @@ def _test_entropy_loss( target_format, entropy_loss_type, dtype, + block_size, group=None, ): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: pytest.skip(reason="Not implemented") # TODO: Test tensor-parallel implementation. logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, target_format, batch_shape, dtype) + local_logits = split_op(logits, group, -1).contiguous() + local_target = target if target_format == TargetFormat.labels else split_op(target, group, -1).contiguous() # Torch serves as the reference implementation. out_ref, grad_ref = entropy_loss_forward_backward( logits=logits, @@ -147,8 +148,8 @@ def _test_entropy_loss( implementation=EntropyLossImplementation.torch, ) out_fused, grad_fused = entropy_loss_forward_backward( - logits=split_op(logits, group, -1), - target=target if target_format == TargetFormat.labels else split_op(target, group, -1), + logits=local_logits, + target=local_target, loss_mask=loss_mask, grad_output=grad_output, group=group, @@ -157,7 +158,6 @@ def _test_entropy_loss( entropy_loss_type=entropy_loss_type, implementation=EntropyLossImplementation.fused, ) - _compare_losses_and_grads( out_fused, out_ref, @@ -168,21 +168,23 @@ def _test_entropy_loss( group=group, ) - if entropy_loss_type != EntropyLossType.cross_entropy or not torch.cuda.is_available() or group is not None: + if entropy_loss_type != EntropyLossType.cross_entropy or not triton_available: # Triton implementation only supports cross-entropy. 
return assert TritonConfig.TRITON_ENABLED out_triton, grad_triton = entropy_loss_forward_backward( - logits=logits, - target=target, + logits=local_logits, + target=local_target, loss_mask=loss_mask, grad_output=grad_output, logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, implementation=EntropyLossImplementation.triton, + group=group, + block_size=block_size, ) - _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref) + _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): @@ -201,18 +203,34 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los group=group, logits_scale_factor=logits_scale_factor, ) - _compare_losses_and_grads(out_fused, out_ref, grad_output is not None, grad_fused, grad_ref, group=group) + _compare_losses_and_grads( + out_fused, + out_ref, + grad_output is not None, + grad_fused, + grad_ref, + threshold=1e-5 if data_type == DataType.float32 else 1e-4, + group=group, + ) @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) -@pytest.mark.parametrize("target_format", TargetFormat) +@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) @pytest.mark.parametrize("entropy_loss_type", EntropyLossType) def test_entropy_loss( - batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, target_format, entropy_loss_type, dtype + batch_shape, + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + target_format, + entropy_loss_type, + dtype, + block_size, ): 
_test_entropy_loss( batch_shape, @@ -223,23 +241,24 @@ def test_entropy_loss( target_format, entropy_loss_type, dtype, + block_size, ) @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) -def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype): +def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size): _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) @pytest.mark.skip(reason="DPO loss is broken") def test_dpo_loss(): - logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) - reference_model_logits = torch.normal(0, 1, (NUM_TOKENS, VOCAB_SIZE)) - labels = torch.randint(0, VOCAB_SIZE, (NUM_TOKENS,)) + logits = torch.normal(0, 1, (200, 100)) + reference_model_logits = torch.normal(0, 1, (200, 100)) + labels = torch.randint(0, 100, (200,)) spans = get_random_spans(np.full(10, 50), 0, 10) fast_llm_loss = dpo_loss(logits, labels, reference_model_logits, spans[::2], spans[1::2]) @@ -249,8 +268,8 @@ def test_dpo_loss(): def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path, seed: int): for batch_shape in _BATCH_SHAPES: - for num_columns, grad_output, logits_scale_factor, loss_masking, dtype in _LOSS_PARAMETERS: - suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}" + for num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size in _LOSS_PARAMETERS: + suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{"_".join([str(i) for i in batch_shape])}" # Entropy loss for entropy_loss_type in EntropyLossType: for 
target_format in TargetFormat: @@ -270,6 +289,7 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa target_format, entropy_loss_type, dtype, + block_size, test_context.group, ) # Z loss @@ -302,8 +322,8 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): _run_lm_loss_distributed, (result_path / "test_losses", random.randint(0, 2**32 - 1)), world_size=2, - backend=DistributedBackend.gloo, - use_cuda=False, # Disable device count check. + backend=DistributedBackend.nccl if (use_nccl := torch.cuda.device_count() >= 2) else DistributedBackend.gloo, + use_cuda=use_nccl, # Disable device count check. ) @@ -311,7 +331,7 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): @pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) @pytest.mark.parametrize( "loss_type", @@ -335,10 +355,11 @@ def test_lm_loss_distributed( logits_scale_factor, loss_masking, dtype, + block_size, ): report_subtest( result_path - / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{"_".join([str(i) for i in batch_shape])}", + / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{"_".join([str(i) for i in batch_shape])}", 2, use_cuda=False, ) From 1d0439e0c6419a7040932ba520a89ba2554437fe Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Thu, 5 Feb 2026 05:06:43 -0500 Subject: [PATCH 11/22] Forward KL --- fast_llm/functional/entropy_loss.py | 4 +- fast_llm/functional/triton/cross_entropy.py | 41 ++++++++++++++++++--- tests/layers/test_lm_losses.py | 4 +- 3 files changed, 40 insertions(+), 9 deletions(-) diff --git 
a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 25e1ae317..4d39b3a77 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -183,7 +183,7 @@ def _fused_cross_entropy_base_from_distribution( # KL loss = mean(log(sum_exp_logits) - sum(probabilities * (logits - log_probabilities)) if return_kl_loss: if target_format == TargetFormat.logits: - target_log_probability = target_logits_norm - sum_exp_target_logits.log().unsqueeze(-1) + target_log_probability = target_logits_norm else: target_log_probability = torch.log(target) logits_norm = logits_norm - target_log_probability @@ -194,6 +194,8 @@ def _fused_cross_entropy_base_from_distribution( all_reduce(predicted_logits, op=ReduceOp.SUM, group=group) per_sample_loss = sum_exp_logits.log() - predicted_logits + if return_kl_loss and target_format == TargetFormat.logits: + per_sample_loss = per_sample_loss - sum_exp_target_logits.log() if grad_output is None: grad = None diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 6fb8e930d..516df0463 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -146,6 +146,7 @@ def triton_predicted_logits_from_distribution( target_logits_scale_factor: tl_constexpr = 1.0, logits_scale_factor: tl_constexpr = 1.0, unscaled_probabilities: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. 
+ return_kl_loss: tl.constexpr = False, ): for col_offset in tl.static_range(0, n_cols, block_size): col_offsets = tl_arange(col_offset, col_offset + block_size) @@ -169,10 +170,14 @@ def triton_predicted_logits_from_distribution( target_max_logits = tl.max(target_logits, 0) target_exp_logits = tl.exp(target_logits - target_max_logits) target_sum_exp_logits = tl.sum(target_exp_logits, 0) - predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits, 0)) + # entropy = sum(logits*exp_logits)/sum_exp_logits - log_sum_exp_logits + # `log_sum_exp_logits` term and division by `sum_exp_logits` kept for later, + logits_shifted = logits - target_logits if return_kl_loss else logits + predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits_shifted, 0)) else: target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - predicted_logits = tl.sum(tl.where(mask, target * logits, 0)) + logits_shifted = logits - tl.log(target) if return_kl_loss else logits + predicted_logits = tl.sum(tl.where(mask, target * logits_shifted, 0)) target_max_logits = None target_sum_exp_logits = None else: @@ -186,18 +191,22 @@ def triton_predicted_logits_from_distribution( target_sum_exp_logits = tl.sum(target_exp_logits, 0) + target_sum_exp_logits * tl.exp( target_max_logits - target_new_max_logits ) + logits_shifted = logits - target_logits if return_kl_loss else logits predicted_logits = predicted_logits * tl.exp(target_max_logits - target_new_max_logits) + tl.sum( - tl.where(mask, target_exp_logits * logits, 0) + tl.where(mask, target_exp_logits * logits_shifted, 0) ) target_max_logits = target_new_max_logits else: - predicted_logits += tl.sum(tl.where(mask, target * logits, 0)) + logits_shifted = logits - tl.log(target) if return_kl_loss else logits + predicted_logits += tl.sum(tl.where(mask, target * logits_shifted, 0)) if from_logits: target = target_exp_logits if not unscaled_probabilities: predicted_logits /= target_sum_exp_logits target /= 
target_sum_exp_logits + if return_kl_loss: + predicted_logits = predicted_logits + tl.log(target_sum_exp_logits) + target_max_logits return predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target @@ -219,6 +228,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( from_logits: tl_constexpr = True, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, + return_kl_loss: tl.constexpr = False, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) @@ -240,6 +250,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( logits_scale_factor=logits_scale_factor, target_logits_scale_factor=target_logits_scale_factor, unscaled_probabilities=True, + return_kl_loss=return_kl_loss, ) ) if predicted_logits_ptr is not None: @@ -275,6 +286,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( grad_logits_stride_0: tl_constexpr = None, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, + return_kl_loss: tl.constexpr = False, ): # TODO: Int64 ptr only if needed? 
block_idx = tl.program_id(0).to(tl.int64) @@ -303,6 +315,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( from_logits=from_logits, logits_scale_factor=logits_scale_factor, target_logits_scale_factor=target_logits_scale_factor, + return_kl_loss=return_kl_loss, ) ) else: @@ -390,7 +403,12 @@ def _cross_entropy_loss_from_distribution( loss_mask: torch.Tensor | None, sum_exp_logits: torch.Tensor, max_logits: torch.Tensor, + target_sum_exp_logits: torch.Tensor | None, + target_max_logits: torch.Tensor | None, + return_kl_loss: bool = False, ) -> torch.Tensor: + if return_kl_loss: + predicted_logits = predicted_logits + target_sum_exp_logits.log() + target_max_logits per_sample_losses = sum_exp_logits.log() + max_logits - predicted_logits if loss_mask is not None: per_sample_losses = torch.where(loss_mask.flatten(), per_sample_losses, 0) @@ -417,7 +435,7 @@ def triton_cross_entropy_forward_backward( TODO: Better handling of `grad_output = None` """ assert TritonConfig.TRITON_ENABLED - Assert.eq(entropy_loss_type, EntropyLossType.cross_entropy) + Assert.incl(entropy_loss_type, (EntropyLossType.cross_entropy, EntropyLossType.forward_kl)) # TODO: Improve assumptions. 
assert logits.is_contiguous() assert target.is_contiguous() @@ -500,6 +518,7 @@ def triton_cross_entropy_forward_backward( target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) @@ -526,6 +545,7 @@ def triton_cross_entropy_forward_backward( target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) @@ -541,7 +561,16 @@ def triton_cross_entropy_forward_backward( target_max_logits = None torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _cross_entropy_loss_from_distribution(predicted_logits, loss_mask, sum_exp_logits, max_logits) + loss = _cross_entropy_loss_from_distribution( + predicted_logits, + loss_mask, + sum_exp_logits, + max_logits, + target_sum_exp_logits=target_sum_exp_logits, + target_max_logits=target_max_logits, + return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl + and target_format == TargetFormat.logits, + ) triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( logits, target, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 37cd99fb0..e54197204 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -168,7 +168,7 @@ def _test_entropy_loss( group=group, ) - if entropy_loss_type != EntropyLossType.cross_entropy or not triton_available: + if entropy_loss_type == EntropyLossType.reverse_kl or not triton_available: # Triton implementation only supports cross-entropy. 
return assert TritonConfig.TRITON_ENABLED @@ -219,7 +219,7 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los @pytest.mark.parametrize( ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) -@pytest.mark.parametrize("target_format", (TargetFormat.logits,)) +@pytest.mark.parametrize("target_format", TargetFormat) @pytest.mark.parametrize("entropy_loss_type", EntropyLossType) def test_entropy_loss( batch_shape, From 1b40518e3550710d03628fc82adde8d6c4811ab3 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 04:17:43 -0500 Subject: [PATCH 12/22] Reverse KL, triton tweaks --- fast_llm/engine/config_utils/run.py | 2 +- fast_llm/functional/config.py | 13 + fast_llm/functional/triton/__init__.py | 13 +- fast_llm/functional/triton/adam.py | 6 +- fast_llm/functional/triton/cross_entropy.py | 495 ++++++++++++++---- fast_llm/functional/triton/mlp.py | 13 +- fast_llm/functional/triton/normalization.py | 16 +- fast_llm/functional/triton/pointwise.py | 19 +- fast_llm/functional/triton/rotary.py | 6 +- fast_llm/functional/triton/sparse_copy.py | 17 +- fast_llm/functional/triton/sparse_linear.py | 30 +- fast_llm/layers/attention/config.py | 5 - fast_llm/layers/attention/rotary/rotary.py | 12 +- .../common/normalization/normalization.py | 4 +- fast_llm/layers/decoder/mlp/mlp.py | 4 +- fast_llm/layers/language_model/loss/config.py | 21 +- .../language_model/loss/entropy_loss.py | 44 +- tests/conftest.py | 5 + tests/functional/test_functional.py | 32 +- tests/functional/test_sparse_matmul.py | 53 +- tests/functional/test_triton_kernels.py | 81 ++- tests/layers/test_lm_losses.py | 34 +- tests/layers/test_rotary.py | 15 +- tests/layers/test_ssm.py | 4 +- tests/models/test_checkpoint.py | 13 +- tests/test_loss_mask.py | 5 +- tests/utils/model_configs.py | 2 + tests/utils/utils.py | 2 + 28 files changed, 628 insertions(+), 338 deletions(-) diff --git 
a/fast_llm/engine/config_utils/run.py b/fast_llm/engine/config_utils/run.py index 2c6c8105f..baa386337 100644 --- a/fast_llm/engine/config_utils/run.py +++ b/fast_llm/engine/config_utils/run.py @@ -101,7 +101,7 @@ def configure_logging( def get_run(self, distributed: "Distributed") -> "Run": from fast_llm.functional.config import TritonConfig - TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels # and distributed.config.use_cuda + TritonConfig.TRITON_ENABLED = self.run.enable_triton_kernels TritonConfig.TRITON_LINEAR = self.run.triton_linear_kernels run = Run(config=self, distributed=distributed) set_global_variables(not self.run.torch_dynamo_enable) diff --git a/fast_llm/functional/config.py b/fast_llm/functional/config.py index 050c700c9..f863a99ac 100644 --- a/fast_llm/functional/config.py +++ b/fast_llm/functional/config.py @@ -14,6 +14,19 @@ class TritonConfig: POINTWISE_BLOCK_SIZE = 1024 MAX_BLOCK_SIZE_BYTES = 65536 + @classmethod + def enabled(cls, device: "torch.device|None" = None, default: bool | None = None) -> bool: + if default is False: + return False + from fast_llm.functional.triton import triton_available, triton_interpret + + available = triton_available and (device is None or device.type == "cuda" or triton_interpret) + if default is None: + default = available and cls.TRITON_ENABLED + else: + assert available + return default + class MLPRecomputeLevel(enum.StrEnum): none = "none" diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index 82f67621e..61ead1c60 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -9,25 +9,32 @@ tl_constexpr = tl.constexpr TritonConfig = triton.Config - triton_available = torch.cuda.is_available() or triton.knobs.runtime.interpret + # Use `TRITON_INTERPRET=1` to enable triton on CPU. 
+ triton_interpret = triton.knobs.runtime.interpret + triton_available = torch.cuda.is_available() or triton_interpret except ImportError as e: triton = InvalidObject(e) tl = triton tl_constexpr = None TritonConfig = lambda *args, **kwargs: None + triton_interpret = False triton_available = False triton_jit = try_decorate(lambda: triton.jit) triton_autotune = try_decorate(lambda: triton.autotune) - if not triton_available: tl_arange = None -elif triton.knobs.runtime.interpret: + tl_full = None +elif triton_interpret: # Workaround for a triton bug. @triton_jit def tl_arange(start, end): return tl.arange(int(start), int(end)) + @triton_jit + def tl_full(shape, value, dtype): + return tl.full(tuple(int(x) for x in shape), value, dtype) + else: tl_arange = tl.arange diff --git a/fast_llm/functional/triton/adam.py b/fast_llm/functional/triton/adam.py index 07ba2df4e..2c835ca05 100644 --- a/fast_llm/functional/triton/adam.py +++ b/fast_llm/functional/triton/adam.py @@ -8,7 +8,7 @@ from torch.optim.adamw import adamw # noqa from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit @triton_jit() @@ -37,7 +37,7 @@ def triton_adam_kernel( # TODO: Int64 ptr only if needed? 
block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel params = tl.load(params_ptr + offsets, mask=mask) @@ -75,7 +75,7 @@ def triton_adam( epsilon: float, use_triton=True, ) -> None: - if not use_triton or (use_triton is None and TritonConfig.TRITON_ENABLED): + if not TritonConfig.enabled(params.device, use_triton): if noop_flag.item() == 0: return adamw( [params], diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/cross_entropy.py index 516df0463..335048770 100644 --- a/fast_llm/functional/triton/cross_entropy.py +++ b/fast_llm/functional/triton/cross_entropy.py @@ -1,34 +1,61 @@ import torch -from fast_llm.functional.config import EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit -from fast_llm.utils import Assert @triton_jit() -def triton_fused_softmax_base( +def triton_fused_softmax_iter_base( logits_ptr, + col_offset: tl.constexpr, n_cols: tl_constexpr, block_size: tl_constexpr, + max_logits=None, + sum_exp_logits=None, + col_offsets=None, + mask=None, logits_scale_factor: tl_constexpr = 1.0, ): - for col_offset in tl.static_range(0, n_cols, block_size): + if col_offsets is None: col_offsets = tl_arange(col_offset, col_offset + block_size) + if mask is None: mask = col_offsets < n_cols - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + + if col_offset == 0: + new_max_logits = tl.max(logits, 0) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + else: 
+ new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) + exp_logits = tl.exp(logits - new_max_logits) + sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) + return logits, exp_logits, sum_exp_logits, new_max_logits, col_offsets, mask - if col_offset == 0: - max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) - else: - new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) - exp_logits = tl.exp(logits - new_max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) - max_logits = new_max_logits - return exp_logits, sum_exp_logits, max_logits + +@triton_jit() +def triton_fused_softmax_base( + logits_ptr, + n_cols: tl_constexpr, + block_size: tl_constexpr, + logits_scale_factor: tl_constexpr = 1.0, +): + exp_logits = None + sum_exp_logits = None + max_logits = None + for col_offset in tl.static_range(0, n_cols, block_size): + logits, exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_iter_base( + logits_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=max_logits, + sum_exp_logits=sum_exp_logits, + logits_scale_factor=logits_scale_factor, + ) + return exp_logits, sum_exp_logits, max_logits, col_offsets, mask @triton_jit() @@ -48,7 +75,7 @@ def triton_cross_entropy_forward_from_labels_parallel_kernel( block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 - exp_logits, sum_exp_logits, max_logits = triton_fused_softmax_base( + exp_logits, sum_exp_logits, max_logits, _, _ = triton_fused_softmax_base( logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) @@ -90,7 +117,7 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( logits_ptr = logits_ptr + block_idx * logits_stride_0 if max_logits_ptr is None or sum_exp_logits_ptr is None: - exp_logits, 
sum_exp_logits, max_logits = triton_fused_softmax_base( + exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor ) else: @@ -116,11 +143,11 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( elif logits_scale_factor != 1.0: grad_losses *= logits_scale_factor # Run in reverse order to maximize input and cache reuse. - col_offset_start = (n_cols - 1) // block_size * block_size + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size for col_offset in tl.static_range(col_offset_start, -1, -block_size): - col_offsets = tl_arange(col_offset, col_offset + block_size) - mask = col_offsets < n_cols if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) if logits_scale_factor != 1.0: logits *= logits_scale_factor @@ -145,68 +172,68 @@ def triton_predicted_logits_from_distribution( from_logits: tl_constexpr = True, target_logits_scale_factor: tl_constexpr = 1.0, logits_scale_factor: tl_constexpr = 1.0, - unscaled_probabilities: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. + return_partial_loss: tl_constexpr = False, # Skip division by sum_exp_logits in the logits case. 
return_kl_loss: tl.constexpr = False, ): + max_logits = None + sum_exp_logits = None + if from_logits: + target_max_logits = None + target_sum_exp_logits = None + else: + target_max_logits = 0 + target_sum_exp_logits = 0 + for col_offset in tl.static_range(0, n_cols, block_size): - col_offsets = tl_arange(col_offset, col_offset + block_size) - mask = col_offsets < n_cols - logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if logits_scale_factor != 1.0: - logits *= logits_scale_factor + logits, exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_iter_base( + logits_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=max_logits, + sum_exp_logits=sum_exp_logits, + logits_scale_factor=logits_scale_factor, + ) if from_logits: - target_logits = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - if target_logits_scale_factor != 1.0: - target_logits *= target_logits_scale_factor - else: - target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - - if col_offset == 0: - max_logits = tl.max(logits, 0) - exp_logits = tl.exp(logits - max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) - if from_logits: - target_max_logits = tl.max(target_logits, 0) - target_exp_logits = tl.exp(target_logits - target_max_logits) - target_sum_exp_logits = tl.sum(target_exp_logits, 0) - # entropy = sum(logits*exp_logits)/sum_exp_logits - log_sum_exp_logits - # `log_sum_exp_logits` term and division by `sum_exp_logits` kept for later, + target_logits, target_exp_logits, target_sum_exp_logits, target_new_max_logits, _, _ = ( + triton_fused_softmax_iter_base( + target_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=target_max_logits, + sum_exp_logits=target_sum_exp_logits, + logits_scale_factor=target_logits_scale_factor, + col_offsets=col_offsets, + mask=mask, + ) + ) + if col_offset == 0: 
logits_shifted = logits - target_logits if return_kl_loss else logits predicted_logits = tl.sum(tl.where(mask, target_exp_logits * logits_shifted, 0)) else: - target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) - logits_shifted = logits - tl.log(target) if return_kl_loss else logits - predicted_logits = tl.sum(tl.where(mask, target * logits_shifted, 0)) - target_max_logits = None - target_sum_exp_logits = None - else: - new_max_logits = tl.maximum(tl.max(logits, 0), max_logits) - exp_logits = tl.exp(logits - new_max_logits) - sum_exp_logits = tl.sum(exp_logits, 0) + sum_exp_logits * tl.exp(max_logits - new_max_logits) - max_logits = new_max_logits - if from_logits: - target_new_max_logits = tl.maximum(tl.max(target_logits, 0), target_max_logits) - target_exp_logits = tl.exp(target_logits - target_new_max_logits) - target_sum_exp_logits = tl.sum(target_exp_logits, 0) + target_sum_exp_logits * tl.exp( - target_max_logits - target_new_max_logits - ) logits_shifted = logits - target_logits if return_kl_loss else logits predicted_logits = predicted_logits * tl.exp(target_max_logits - target_new_max_logits) + tl.sum( tl.where(mask, target_exp_logits * logits_shifted, 0) ) - target_max_logits = target_new_max_logits + target_max_logits = target_new_max_logits + else: + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if col_offset == 0: + logits_shifted = logits - tl.log(target) if return_kl_loss else logits + predicted_logits = tl.sum(tl.where(mask, target * logits_shifted, 0)) else: logits_shifted = logits - tl.log(target) if return_kl_loss else logits predicted_logits += tl.sum(tl.where(mask, target * logits_shifted, 0)) if from_logits: target = target_exp_logits - if not unscaled_probabilities: + if not return_partial_loss: predicted_logits /= target_sum_exp_logits target /= target_sum_exp_logits if return_kl_loss: - predicted_logits = predicted_logits + tl.log(target_sum_exp_logits) + 
target_max_logits + predicted_logits += tl.log(target_sum_exp_logits) + target_max_logits return predicted_logits, exp_logits, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target @@ -224,7 +251,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( sum_exp_logits_ptr=None, target_max_logits_ptr=None, target_sum_exp_logits_ptr=None, - predicted_logits_ptr=None, + partial_losses_ptr=None, from_logits: tl_constexpr = True, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, @@ -237,7 +264,7 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: # This entry is masked, ignore. - tl.store(predicted_logits_ptr + block_idx, 0) + tl.store(partial_losses_ptr + block_idx, 0) return predicted_logits, _, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits, target = ( @@ -249,12 +276,12 @@ def triton_cross_entropy_from_distribution_forward_parallel_kernel( from_logits=from_logits, logits_scale_factor=logits_scale_factor, target_logits_scale_factor=target_logits_scale_factor, - unscaled_probabilities=True, + return_partial_loss=True, return_kl_loss=return_kl_loss, ) ) - if predicted_logits_ptr is not None: - tl.store(predicted_logits_ptr + block_idx, predicted_logits) + if partial_losses_ptr is not None: + tl.store(partial_losses_ptr + block_idx, predicted_logits) if max_logits_ptr is not None: tl.store(max_logits_ptr + block_idx, max_logits) if sum_exp_logits_ptr is not None: @@ -275,6 +302,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target_stride_0: tl_constexpr, block_size: tl_constexpr, loss_mask_ptr=None, + partial_losses_ptr=None, losses_ptr=None, max_logits_ptr=None, sum_exp_logits_ptr=None, @@ -321,11 +349,22 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( else: max_logits = tl.load(max_logits_ptr + block_idx) sum_exp_logits = 
tl.load(sum_exp_logits_ptr + block_idx) - if grad_losses is not None and from_logits: + if from_logits: target_max_logits = tl.load(target_max_logits_ptr + block_idx) target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) if losses_ptr is not None: + # if return_kl_loss: + # predicted_logits = predicted_logits + target_sum_exp_logits.log() + target_max_logits + # per_sample_losses = sum_exp_logits.log() + max_logits - predicted_logits + # if loss_mask is not None: + # per_sample_losses = torch.where(loss_mask.flatten(), per_sample_losses, 0) + if partial_losses_ptr is not None: + predicted_logits = tl.load(partial_losses_ptr + block_idx) + if from_logits: + predicted_logits /= target_sum_exp_logits + if return_kl_loss: + predicted_logits += tl.log(target_sum_exp_logits) + target_max_logits # per_sample_loss = log(sum_exp_logits) - sum(probabilities * logits) loss = tl.log(sum_exp_logits) + max_logits - predicted_logits tl.store(losses_ptr + block_idx, loss) @@ -334,7 +373,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( if logits_scale_factor != 1.0: grad_losses *= logits_scale_factor # grad / grad_output = exp_logits / sum_exp_logits - target_probabilities. 
- col_offset_start = (n_cols - 1) // block_size * block_size + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size for col_offset in tl.static_range(col_offset_start, -1, -block_size): col_offsets = tl_arange(col_offset, col_offset + block_size) mask = col_offsets < n_cols @@ -355,6 +394,239 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) +@triton_jit() +def triton_reverse_kl_forward_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + block_size: tl_constexpr, + from_logits: tl_constexpr = True, + target_logits_scale_factor: tl_constexpr = 1.0, + logits_scale_factor: tl_constexpr = 1.0, + return_partial_loss: tl_constexpr = False, +): + max_logits = None + sum_exp_logits = None + if from_logits: + target_max_logits = None + target_sum_exp_logits = None + else: + target_max_logits = 0 + target_sum_exp_logits = 0 + + for col_offset in tl.static_range(0, n_cols, block_size): + logits, exp_logits, sum_exp_logits, new_max_logits, col_offsets, mask = triton_fused_softmax_iter_base( + logits_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=max_logits, + sum_exp_logits=sum_exp_logits, + logits_scale_factor=logits_scale_factor, + ) + + # print("sum_exp_logits", sum_exp_logits) + # print("max_logits", new_max_logits) + if from_logits: + # log_target excludes the log_sum_exp term to be added later + log_target, _, target_sum_exp_logits, target_new_max_logits, _, _ = triton_fused_softmax_iter_base( + target_ptr, + col_offset=col_offset, + n_cols=n_cols, + block_size=block_size, + max_logits=target_max_logits, + sum_exp_logits=target_sum_exp_logits, + logits_scale_factor=target_logits_scale_factor, + col_offsets=col_offsets, + mask=mask, + ) + target = log_target + # print("target_sum_exp_logits", target_sum_exp_logits) + # print("new_max_logits", target_new_max_logits) + else: + 
target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + log_target = tl.log(target) + if col_offset == 0: + # predicted_log_probability=logits - new_max_logits - tl.log(sum_exp_logits) + # target_log_probability=log_target-target_new_max_logits-tl.log(target_sum_exp_logits) + # print("predicted_log_probability", predicted_log_probability) + # print("target_log_probability", target_log_probability) + # print("IUWH", exp_logits * (predicted_log_probability-target_log_probability)/sum_exp_logits) + loss = tl.sum(tl.where(mask, exp_logits * (logits - log_target), 0)) + # print("max_logits", new_max_logits) + # print("partial_losses", exp_logits * (logits-log_target)) + + else: + loss = loss * tl.exp(max_logits - new_max_logits) + tl.sum( + tl.where(mask, exp_logits * (logits - log_target), 0) + ) + max_logits = new_max_logits + if from_logits: + target_max_logits = target_new_max_logits + + # print("partial_loss", loss) + if not return_partial_loss: + loss = loss / sum_exp_logits - tl.log(sum_exp_logits) - max_logits + if from_logits: + loss = loss + tl.log(target_sum_exp_logits) + target_max_logits + + return loss, logits, target, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits + + +@triton_jit() +def triton_reverse_kl_forward_kernel_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, + block_size: tl_constexpr, + loss_mask_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + partial_losses_ptr=None, + from_logits: tl_constexpr = True, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? 
+ block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + tl.store(partial_losses_ptr + block_idx, 0) + return + + partial_loss, _, _, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits = ( + triton_reverse_kl_forward_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + return_partial_loss=True, + ) + ) + if partial_losses_ptr is not None: + tl.store(partial_losses_ptr + block_idx, partial_loss) + if max_logits_ptr is not None: + tl.store(max_logits_ptr + block_idx, max_logits) + if sum_exp_logits_ptr is not None: + tl.store(sum_exp_logits_ptr + block_idx, sum_exp_logits) + + if target_max_logits_ptr is not None: + tl.store(target_max_logits_ptr + block_idx, target_max_logits) + if target_sum_exp_logits_ptr is not None: + tl.store(target_sum_exp_logits_ptr + block_idx, target_sum_exp_logits) + + +@triton_jit() +def triton_reverse_kl_forward_backward_kernel_from_distribution( + logits_ptr, + target_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + target_stride_0: tl_constexpr, + block_size: tl_constexpr, + loss_mask_ptr=None, + partial_losses_ptr=None, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + target_max_logits_ptr=None, + target_sum_exp_logits_ptr=None, + from_logits: tl_constexpr = True, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + logits_scale_factor: tl_constexpr = 1.0, + target_logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? 
+ block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + target_ptr = target_ptr + block_idx * target_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + loss, logits, target, sum_exp_logits, max_logits, target_sum_exp_logits, target_max_logits = ( + triton_reverse_kl_forward_from_distribution( + logits_ptr, + target_ptr, + n_cols=n_cols, + block_size=block_size, + from_logits=from_logits, + logits_scale_factor=logits_scale_factor, + target_logits_scale_factor=target_logits_scale_factor, + ) + ) + else: + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + if from_logits: + target_max_logits = tl.load(target_max_logits_ptr + block_idx) + target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) + + # print("sum_exp_logits", sum_exp_logits) + # print("max_logits", max_logits) + + # if from_logits: + # print("target_sum_exp_logits", target_sum_exp_logits) + # print("target_max_logits", target_max_logits) + + if losses_ptr is not None: + if partial_losses_ptr is not None: + loss = tl.load(partial_losses_ptr + block_idx) + # print("partial_loss", loss) + loss = loss / sum_exp_logits - tl.log(sum_exp_logits) - max_logits + if from_logits: + loss = loss + tl.log(target_sum_exp_logits) + target_max_logits + tl.store(losses_ptr + block_idx, loss) + + if grad_losses is not None: + if logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + col_offset_start: tl.constexpr = 
(n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if from_logits and target_logits_scale_factor != 1.0: + target *= target_logits_scale_factor + if from_logits: + target_log_probability = target - target_max_logits - tl.log(target_sum_exp_logits) + else: + target_log_probability = tl.log(target) + + predicted_probability = tl.exp(logits - max_logits) / sum_exp_logits + predicted_log_probability = logits - max_logits - tl.log(sum_exp_logits) + grad_logits = ( + grad_losses * (predicted_log_probability - target_log_probability - loss) * predicted_probability + ) + tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + + @torch.compile def _rescale_sum_exp_logits( sum_exp_logits: torch.Tensor, @@ -389,12 +661,11 @@ def _cross_entropy_loss_from_labels( @torch.compile def _rescale_predicted_logits( predicted_logits: torch.Tensor, - target_sum_exp_logits: torch.Tensor, local_target_max_logits: torch.Tensor, target_max_logits: torch.Tensor, ): # We skipped the division by `target_sum_exp_logits` in the triton kernel so we do it here. 
- return predicted_logits * torch.exp(local_target_max_logits - target_max_logits) / target_sum_exp_logits + return predicted_logits * torch.exp(local_target_max_logits - target_max_logits) @torch.compile @@ -415,7 +686,7 @@ def _cross_entropy_loss_from_distribution( return per_sample_losses.mean() -def triton_cross_entropy_forward_backward( +def triton_entropy_loss_forward_backward( logits: torch.Tensor, target: torch.Tensor, loss_mask: torch.Tensor | None, @@ -434,8 +705,6 @@ def triton_cross_entropy_forward_backward( Compared to a standard pytorch implementation, this reduces memory usage (of logits) by 3x and memory I/O by 5x. TODO: Better handling of `grad_output = None` """ - assert TritonConfig.TRITON_ENABLED - Assert.incl(entropy_loss_type, (EntropyLossType.cross_entropy, EntropyLossType.forward_kl)) # TODO: Improve assumptions. assert logits.is_contiguous() assert target.is_contiguous() @@ -465,6 +734,7 @@ def triton_cross_entropy_forward_backward( } ) if target_format == TargetFormat.labels: + assert entropy_loss_type != EntropyLossType.reverse_kl if group is None: losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( @@ -476,21 +746,21 @@ def triton_cross_entropy_forward_backward( ) loss = losses.mean() else: - predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(predicted_logits) - sum_exp_logits = torch.empty_like(predicted_logits) + partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(partial_losses) + sum_exp_logits = torch.empty_like(partial_losses) triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( logits, target, max_logits_ptr=local_max_logits, sum_exp_logits_ptr=sum_exp_logits, - predicted_logits_ptr=predicted_logits, + predicted_logits_ptr=partial_losses, col_min=n_cols * group.rank(), **kwargs, ) max_logits, 
sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - loss = _cross_entropy_loss_from_labels(predicted_logits, target, sum_exp_logits, max_logits) + torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) + loss = _cross_entropy_loss_from_labels(partial_losses, target, sum_exp_logits, max_logits) if grad_output is not None: triton_cross_entropy_forward_backward_from_labels_kernel[(n_rows,)]( logits, @@ -502,11 +772,17 @@ def triton_cross_entropy_forward_backward( **backward_kwargs, ) else: + if loss_mask is not None: + assert loss_mask.is_contiguous() + if entropy_loss_type == EntropyLossType.reverse_kl: + kernel = triton_reverse_kl_forward_backward_kernel_from_distribution + else: + kernel = triton_cross_entropy_from_distribution_forward_backward_kernel + kwargs["return_kl_loss"] = entropy_loss_type == EntropyLossType.forward_kl + if group is None: losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) - if loss_mask is not None: - assert loss_mask.is_contiguous() - triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + kernel[(n_rows,)]( logits, target, loss_mask_ptr=loss_mask, @@ -518,22 +794,27 @@ def triton_cross_entropy_forward_backward( target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, - return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) loss = losses.mean() else: - predicted_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) - local_max_logits = torch.empty_like(predicted_logits) - sum_exp_logits = torch.empty_like(predicted_logits) + partial_losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + local_max_logits = torch.empty_like(partial_losses) + sum_exp_logits = 
torch.empty_like(partial_losses) if target_format == TargetFormat.logits: - local_target_max_logits = torch.empty_like(predicted_logits) - target_sum_exp_logits = torch.empty_like(predicted_logits) + local_target_max_logits = torch.empty_like(partial_losses) + target_sum_exp_logits = torch.empty_like(partial_losses) else: local_target_max_logits = target_sum_exp_logits = None - triton_cross_entropy_from_distribution_forward_parallel_kernel[(n_rows,)]( + forward_kernel = ( + triton_reverse_kl_forward_kernel_from_distribution + if entropy_loss_type == EntropyLossType.reverse_kl + else triton_cross_entropy_from_distribution_forward_parallel_kernel + ) + + forward_kernel[(n_rows,)]( logits, target, loss_mask_ptr=loss_mask, @@ -541,11 +822,10 @@ def triton_cross_entropy_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=local_target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - predicted_logits_ptr=predicted_logits, + partial_losses_ptr=partial_losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, - return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl, **kwargs, **backward_kwargs, ) @@ -554,24 +834,17 @@ def triton_cross_entropy_forward_backward( target_max_logits, target_sum_exp_logits = _parallel_sum_exp_logits( target_sum_exp_logits, local_target_max_logits, group ) - predicted_logits = _rescale_predicted_logits( - predicted_logits, target_sum_exp_logits, local_target_max_logits, target_max_logits - ) + if entropy_loss_type != EntropyLossType.reverse_kl: + partial_losses = _rescale_predicted_logits( + partial_losses, local_target_max_logits, target_max_logits + ) else: target_max_logits = None - torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=group) - - loss = _cross_entropy_loss_from_distribution( - predicted_logits, - loss_mask, - sum_exp_logits, - max_logits, - 
target_sum_exp_logits=target_sum_exp_logits, - target_max_logits=target_max_logits, - return_kl_loss=entropy_loss_type == EntropyLossType.forward_kl - and target_format == TargetFormat.logits, - ) - triton_cross_entropy_from_distribution_forward_backward_kernel[(n_rows,)]( + if entropy_loss_type == EntropyLossType.reverse_kl: + partial_losses = _rescale_predicted_logits(partial_losses, local_max_logits, max_logits) + torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) + + kernel[(n_rows,)]( logits, target, loss_mask_ptr=loss_mask, @@ -579,11 +852,13 @@ def triton_cross_entropy_forward_backward( sum_exp_logits_ptr=sum_exp_logits, target_max_logits_ptr=target_max_logits, target_sum_exp_logits_ptr=target_sum_exp_logits, - predicted_logits_ptr=predicted_logits, + partial_losses_ptr=partial_losses, + losses_ptr=partial_losses, target_stride_0=target.stride(-2), target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, **kwargs, **backward_kwargs, ) + loss = partial_losses.mean() return loss, grad_logits diff --git a/fast_llm/functional/triton/mlp.py b/fast_llm/functional/triton/mlp.py index 286e7159a..7949faaf0 100644 --- a/fast_llm/functional/triton/mlp.py +++ b/fast_llm/functional/triton/mlp.py @@ -14,7 +14,7 @@ output_parallel_linear_forward, update_linear_gradients, ) -from fast_llm.functional.triton import tl, tl_constexpr, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton_jit from fast_llm.functional.triton.sparse_copy import ( SparseMap, copy_dense_to_sparse_backward, @@ -37,7 +37,7 @@ def triton_mlp_activation_forward_kernel( ): # TODO: Int64 ptr only if needed? 
row_idx = tl.program_id(0).to(tl.int64) - columns = tl.program_id(1) * block_size + tl.arange(0, block_size) + columns = tl.program_id(1) * block_size + tl_arange(0, block_size) output_offsets = n_cols * row_idx + columns input_offsets = 2 * n_cols * row_idx + columns if gated else output_offsets @@ -85,7 +85,7 @@ def triton_mlp_activation_backward_kernel( ): # TODO: Int64 ptr only if needed? row_idx = tl.program_id(0).to(tl.int64) - columns = tl.program_id(1) * block_size + tl.arange(0, block_size) + columns = tl.program_id(1) * block_size + tl_arange(0, block_size) output_offsets = n_cols * row_idx + columns input_offsets = 2 * n_cols * row_idx + columns if gated else output_offsets @@ -219,6 +219,7 @@ def mlp_forward( recompute_level: MLPRecomputeLevel = MLPRecomputeLevel.none, transposed_layer_2_weight: bool = False, sparse_map: SparseMap | None = None, + use_triton: bool | None = None, ) -> tuple[torch.Tensor, list[typing.Any] | None]: # Sparse copy input_shape = input_.shape @@ -235,7 +236,7 @@ def mlp_forward( input_ = None # Activation - if TritonConfig.TRITON_ENABLED and intermediate_1.device.type == "cuda": + if TritonConfig.enabled(intermediate_1.device, use_triton): intermediate_2, _ = triton_mlp_activation_forward(intermediate_1, gated, activation_type) else: do_grad = training and not recompute_level.recompute_activation @@ -287,6 +288,7 @@ def mlp_forward( transposed_layer_2_weight, sparse_map, input_shape, + use_triton, ] if training else None @@ -313,6 +315,7 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ transposed_layer_2_weight, sparse_map, input_shape, + use_triton, ) = context context.clear() @@ -344,7 +347,7 @@ def mlp_backward(grad_output: torch.Tensor, context: list[typing.Any]) -> tuple[ )[0] # Activation recomputation and/or backward - if TritonConfig.TRITON_ENABLED and grad_output.device.type == "cuda": + if TritonConfig.enabled(grad_output.device, use_triton): grad_intermediate_1, intermediate_2_ = 
triton_mlp_activation_backward( grad_intermediate_2, (intermediate_1, gated, activation_type), intermediate_2 is None ) diff --git a/fast_llm/functional/triton/normalization.py b/fast_llm/functional/triton/normalization.py index a018ad44b..9538a9275 100644 --- a/fast_llm/functional/triton/normalization.py +++ b/fast_llm/functional/triton/normalization.py @@ -4,7 +4,7 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, tl_full, triton, triton_jit from fast_llm.tensor import param_get_and_unset_is_zero @@ -23,7 +23,7 @@ def triton_normalization_forward_kernel( ): # Program dimensions row = tl.program_id(0).to(tl.int64) - cols = tl.arange(0, block_size) + cols = tl_arange(0, block_size) mask = cols < n_cols offsets = row * n_cols + cols @@ -75,10 +75,10 @@ def triton_normalization_backward_kernel_1( block_size_row: tl_constexpr, ): # row_start = tl.program_id(0)*block_size_row - rows = tl.program_id(0) * block_size_row + tl.arange(0, block_size_row)[:, None] + rows = tl.program_id(0) * block_size_row + tl_arange(0, block_size_row)[:, None] row_mask = rows < n_rows - cols = tl.arange(0, block_size)[None, :] + cols = tl_arange(0, block_size)[None, :] col_mask = cols < n_cols mask = col_mask & row_mask @@ -140,15 +140,15 @@ def triton_normalization_backward_kernel_2( block_size_n: tl_constexpr, ): pid = tl.program_id(0) - cols = pid * block_size_n + tl.arange(0, block_size_n) - grad_weight_partial_sum = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) + cols = pid * block_size_n + tl_arange(0, block_size_n) + grad_weight_partial_sum = tl_full((block_size_m, block_size_n), 0, dtype=tl.float32) if has_bias: - grad_bias_partial_sum = tl.zeros((block_size_m, block_size_n), dtype=tl.float32) + grad_bias_partial_sum = tl_full((block_size_m, block_size_n), 0, 
dtype=tl.float32) col_mask = cols < n_cols # Partial sums. for i in range(0, m, block_size_m): - rows = i + tl.arange(0, block_size_m) + rows = i + tl_arange(0, block_size_m) mask = (rows[:, None] < m) & (cols[None, :] < n_cols) offsets = rows[:, None] * n_cols + cols[None, :] grad_weight_partial_sum += tl.load(grad_weight_partial_ptr + offsets, mask=mask, other=0.0) diff --git a/fast_llm/functional/triton/pointwise.py b/fast_llm/functional/triton/pointwise.py index 22676ae1a..44bb805f2 100644 --- a/fast_llm/functional/triton/pointwise.py +++ b/fast_llm/functional/triton/pointwise.py @@ -7,7 +7,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, tl_full, triton, triton_jit @triton_jit() @@ -19,7 +19,7 @@ def triton_copy_kernel( ): # TODO: Int64 ptr only if needed? block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel input_ = tl.load(input_ptr + offsets, mask=mask) tl.store(out_ptr + offsets, input_, mask=mask) @@ -28,11 +28,12 @@ def triton_copy_kernel( def triton_copy( input_: torch.Tensor, out: torch.Tensor, + use_triton: bool | None = None, ) -> torch.Tensor: """ A triton implementation of tensor copying (`torch.Tensor.copy_()`). """ - if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": + if not TritonConfig.enabled(input_.device, use_triton): return out.copy_(input_) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -53,19 +54,20 @@ def triton_fill_kernel( ): # TODO: Int64 ptr only if needed? 
block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel - tl.store(input_ptr + offsets, tl.full((block_size,), value, dtype), mask=mask) + tl.store(input_ptr + offsets, tl_full((block_size,), value, dtype), mask=mask) def triton_fill( input_: torch.Tensor, value: float | int, + use_triton: bool | None = None, ) -> torch.Tensor: """ A faster triton implementation of tensor copying (`torch.Tensor.fill_()`). """ - if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": + if not TritonConfig.enabled(input_.device, use_triton): return input_.fill_(value) # TODO: Improve assumptions. assert input_.is_contiguous() @@ -91,7 +93,7 @@ def triton_add_kernel( ): # TODO: Int64 ptr only if needed? block_start = tl.program_id(axis=0).to(tl.int64) * block_size - offsets = block_start + tl.arange(0, block_size) + offsets = block_start + tl_arange(0, block_size) mask = offsets < numel input_ = tl.load(input_ptr + offsets, mask=mask) other = tl.load(other_ptr + offsets, mask=mask) @@ -102,11 +104,12 @@ def triton_add( input_: torch.Tensor, other: torch.Tensor, out: torch.Tensor | None = None, + use_triton: bool | None = None, ) -> torch.Tensor: """ A faster triton implementation of tensor addition (`torch.Tensor.add()`). """ - if not TritonConfig.TRITON_ENABLED or input_.device.type != "cuda": + if not TritonConfig.enabled(input_.device, use_triton): return torch.add(input_, other, out=out) # TODO: Improve assumptions. 
assert input_.is_contiguous() diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index c510925c6..2c93776af 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -2,7 +2,7 @@ from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit from fast_llm.utils import div @@ -25,8 +25,8 @@ def triton_rotary_kernel( pid_1 = tl.program_id(axis=1) # Head index position_id = pid_0 % seq_len - offsets = tl.arange(0, rotary_block_size) - head_offsets = pid_1 * head_block_size + tl.arange(0, head_block_size)[:, None] + offsets = tl_arange(0, rotary_block_size) + head_offsets = pid_1 * head_block_size + tl_arange(0, head_block_size)[:, None] input_offsets = stride_0 * (pid_0 // seq_len) + stride_1 * position_id + stride_2 * head_offsets + offsets[None, :] input_re_ptr = input_ptr + input_offsets input_im_ptr = input_re_ptr + rotary_dim diff --git a/fast_llm/functional/triton/sparse_copy.py b/fast_llm/functional/triton/sparse_copy.py index 7c803689c..e68692d9c 100644 --- a/fast_llm/functional/triton/sparse_copy.py +++ b/fast_llm/functional/triton/sparse_copy.py @@ -5,7 +5,7 @@ from fast_llm.engine.config_utils.data_type import DataType from fast_llm.functional.autograd import wrap_forward_backward from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, TritonConfig -from fast_llm.functional.triton import tl, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit @dataclasses.dataclass() @@ -36,7 +36,7 @@ def copy_dense_to_sparse_kernel( block_size: tl_constexpr, ): dense_row = tl.program_id(0) - offsets = tl.arange(0, block_size) + block_size * tl.program_id(1) + offsets = tl_arange(0, block_size) + block_size 
* tl.program_id(1) mask = None if num_columns % block_size == 0 else offsets < num_columns out = tl.load(input_ptr + dense_row * num_columns + offsets, mask=mask) # Write to each expert. @@ -78,7 +78,7 @@ def copy_sparse_to_dense_kernel( block_size: tl_constexpr, ): dense_row = tl.program_id(0) - offsets = tl.arange(0, block_size) + block_size * tl.program_id(1) + offsets = tl_arange(0, block_size) + block_size * tl.program_id(1) mask = None if num_columns % block_size == 0 else offsets < num_columns out = tl.zeros((block_size,), tl.float32) # Sum over experts. @@ -125,7 +125,7 @@ def copy_sparse_to_dense_grad_score_kernel( grad_output_ptr += dense_row * num_columns input_ptr += sparse_row * num_columns - offsets = tl.arange(0, block_size) + offsets = tl_arange(0, block_size) if num_columns % block_size == 0: grad_scores = tl.load(input_ptr + offsets).to(tl.float32) * tl.load(grad_output_ptr + offsets).to(tl.float32) @@ -216,8 +216,8 @@ def sparse_map_kernel( we use a one-hot representation to get the quantities we want. TODO: Next triton release will support tl.histogram, maybe argsort. """ - block_range = tl.arange(0, block_size) - expert_range = tl.arange(0, block_size_expert) + block_range = tl_arange(0, block_size) + expert_range = tl_arange(0, block_size_expert) expert_mask = None if block_size_expert == num_experts else expert_range < num_experts if num_sparse_rows >= block_size: @@ -256,7 +256,7 @@ def sparse_map_kernel( if sparse_rows_ptr is not None: # Assign a new unique index to each row so that it lies in the range (expert_begin, expert_end) # for its assigned expert. 
- block_range = tl.arange(0, block_size) + block_range = tl_arange(0, block_size) for i in range(tl.cdiv(num_sparse_rows, block_size)): if num_sparse_rows % block_size == 0: mask = None @@ -307,7 +307,8 @@ def get_sparse_map( num_rows_unpadded = num_rows_dense * num_experts_per_token max_rows = (num_rows_unpadded + num_experts * pad_to_multiple) // pad_to_multiple * pad_to_multiple dtype = torch.int16 if max_rows < 32768 else torch.int32 - if (use_triton is None and TritonConfig.TRITON_ENABLED) or use_triton: + + if TritonConfig.enabled(top_experts.device, use_triton): expert_ends, expert_pad_begins = top_experts.new_empty((2 * num_experts,), dtype=dtype).chunk(2) sparse_rows = expert_ends.new_empty(num_rows_dense, num_experts_per_token) sparse_map_kernel[(triton.cdiv(num_rows_dense, block_size),)]( diff --git a/fast_llm/functional/triton/sparse_linear.py b/fast_llm/functional/triton/sparse_linear.py index ae46655ea..15af789d7 100644 --- a/fast_llm/functional/triton/sparse_linear.py +++ b/fast_llm/functional/triton/sparse_linear.py @@ -2,7 +2,7 @@ import torch -from fast_llm.functional.triton import TritonConfig, tl, tl_constexpr, triton, triton_autotune, triton_jit +from fast_llm.functional.triton import TritonConfig, tl, tl_arange, tl_constexpr, triton, triton_autotune, triton_jit from fast_llm.functional.triton.sparse_copy import SparseMap from fast_llm.utils import Assert, div @@ -99,9 +99,9 @@ def dense_matmul_kernel( col_offset = pid_col * block_size_col # Pointers - row_range = tl.arange(0, block_size_row)[:, None] + row_offset - col_range = tl.arange(0, block_size_col)[None, :] + col_offset - inner_range = tl.arange(0, block_size_inner) + row_range = tl_arange(0, block_size_row)[:, None] + row_offset + col_range = tl_arange(0, block_size_col)[None, :] + col_offset + inner_range = tl_arange(0, block_size_inner) lhs_ptr += row_range * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += inner_range[:, None] * rhs_stride_inner + col_range * 
rhs_stride_col out_ptr += row_range * out_stride_row + col_range * out_stride_col @@ -228,7 +228,7 @@ def output_sparse_matmul_kernel( # Grid offsets row_offset = pid_row * block_size_row col_sparse_offset = pid_col * block_size_col - sparse_range = tl.arange(0, padded_sparse_dim) + sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa if sparse_index == sparse_dim: @@ -236,9 +236,9 @@ def output_sparse_matmul_kernel( col_dense_offset = col_sparse_offset + sparse_index * col_sparse_dim # Pointers - row_range = tl.arange(0, block_size_row)[:, None] - col_range = tl.arange(0, block_size_col)[None, :] - inner_range = tl.arange(0, block_size_inner) + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] + inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += inner_range[:, None] * rhs_stride_inner + (col_dense_offset + col_range) * rhs_stride_col out_ptr += (row_offset + row_range) * out_stride_row + (col_sparse_offset + col_range) * out_stride_col @@ -351,7 +351,7 @@ def input_inner_sparse_matmul_kernel( # Grid offsets row_offset = pid_row * block_size_row - sparse_range = tl.arange(0, padded_sparse_dim) + sparse_range = tl_arange(0, padded_sparse_dim) expert_ends = tl.load(expert_ends_ptr + sparse_range, mask=sparse_range < sparse_dim, other=row_dim) sparse_index = tl.sum((expert_ends <= row_offset).to(tl.int64)) # noqa if sparse_index == sparse_dim: @@ -360,9 +360,9 @@ def input_inner_sparse_matmul_kernel( col_offset = pid_col * block_size_col # Pointers - row_range = tl.arange(0, block_size_row)[:, None] - col_range = tl.arange(0, block_size_col)[None, :] - inner_range = tl.arange(0, block_size_inner) + row_range = tl_arange(0, block_size_row)[:, None] + 
col_range = tl_arange(0, block_size_col)[None, :] + inner_range = tl_arange(0, block_size_inner) lhs_ptr += (row_offset + row_range) * lhs_stride_row + inner_range[None, :] * lhs_stride_inner rhs_ptr += (inner_dense_offset + inner_range[:, None]) * rhs_stride_inner + ( col_offset + col_range @@ -485,9 +485,9 @@ def input_row_sparse_matmul_kernel( inner_offset = (inner_begin // block_size_inner) * block_size_inner # Pointers - row_range = tl.arange(0, block_size_row)[:, None] - col_range = tl.arange(0, block_size_col)[None, :] - inner_range = tl.arange(0, block_size_inner) + inner_offset + row_range = tl_arange(0, block_size_row)[:, None] + col_range = tl_arange(0, block_size_col)[None, :] + inner_range = tl_arange(0, block_size_inner) + inner_offset lhs_ptr += (row_sparse_offset + row_range) * lhs_stride_row rhs_ptr += (col_offset + col_range) * rhs_stride_col out_ptr += (row_dense_offset + row_range) * out_stride_row + (col_offset + col_range) * out_stride_col diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 626a8fde6..40baf2009 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -1,10 +1,8 @@ import enum import logging import typing -import warnings from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none -from fast_llm.functional.config import TritonConfig from fast_llm.layers.attention.rotary.config import RotaryConfig from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.linear.config import AffineLinearConfig @@ -132,9 +130,6 @@ class AttentionConfig(MixerConfig): def _validate(self) -> None: super()._validate() - if not TritonConfig.TRITON_ENABLED: - warnings.warn("Triton is disabled, but triton rotary kernel will be used anyway.") - Assert.multiple(self.heads, self.head_groups) if not self.causal: diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 
304f96b83..307256a72 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -100,11 +100,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = ( - triton_rotary_autograd_ - if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" - else rotary_embeddings_real - ) + rotary_fn = triton_rotary_autograd_ if TritonConfig.enabled(query.device) else rotary_embeddings_real query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key @@ -238,11 +234,7 @@ def preprocess(self, kwargs: dict[str, typing.Any]) -> None: def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: - rotary_fn = ( - triton_rotary_autograd_ - if TritonConfig.TRITON_ENABLED and query.device.type == "cuda" - else rotary_embeddings_real - ) + rotary_fn = triton_rotary_autograd_ if TritonConfig.enabled(query.device) else rotary_embeddings_real query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key diff --git a/fast_llm/layers/common/normalization/normalization.py b/fast_llm/layers/common/normalization/normalization.py index 55e62af22..6fe1ea519 100644 --- a/fast_llm/layers/common/normalization/normalization.py +++ b/fast_llm/layers/common/normalization/normalization.py @@ -190,7 +190,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | and not self._config.zero_centered ): implementation = NormalizationImplementation.fast - elif (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: + elif TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: log_main_rank("Fast layer norm 
unavailable, using backup triton implementation.") implementation = NormalizationImplementation.triton elif _fused_normalization_available: @@ -259,7 +259,7 @@ def __init__(self, config: ConfigType, hidden_dim: TensorDim, lr_scale: float | assert not hidden_dim.is_parallel implementation = self._config.implementation if implementation == NormalizationImplementation.auto: - if (TritonConfig.TRITON_ENABLED and torch.cuda.is_available()) or self._config.zero_centered: + if TritonConfig.enabled(torch.device("cuda")) or self._config.zero_centered: implementation = NormalizationImplementation.triton elif _fused_normalization_available: log_main_rank("Triton RMS norm unavailable, using fused implementation.") diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 882963ce9..88c86c8aa 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -44,7 +44,9 @@ def __init__( self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) intermediate_1_dim, self._intermediate_2_dim = self._get_intermediate_dims() - self._activation_fn = triton_mlp_activation_autograd if TritonConfig.TRITON_ENABLED else torch_mlp_activation + self._activation_fn = ( + triton_mlp_activation_autograd if TritonConfig.enabled(torch.device("cuda")) else torch_mlp_activation + ) # So both layers' weights have shape (num_experts [* gate_up] * ffn, hidden_size) self.layer_1 = self._config.layer_1.get_layer( diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index f531a1d46..d636d6af7 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -2,7 +2,7 @@ from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType +from 
fast_llm.functional.config import EntropyLossType from fast_llm.layers.block.config import BlockKwargs from fast_llm.utils import Assert @@ -77,11 +77,10 @@ class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): desc="Type of loss to use.", hint=FieldHint.core, ) - - implementation: EntropyLossImplementation = Field( - default=EntropyLossImplementation.auto, - desc="Loss implementation.", - hint=FieldHint.performance, + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. Default: use if available.", + hint=FieldHint.expert, ) @property @@ -100,11 +99,6 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): desc="Type of loss to use.", hint=FieldHint.core, ) - implementation: EntropyLossImplementation = Field( - default=EntropyLossImplementation.auto, - desc="Loss implementation.", - hint=FieldHint.performance, - ) reference_model: str = Field( default="teacher", desc="Name of the reference model for knowledge distillation.", @@ -116,6 +110,11 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): desc="Temperature for teacher softmax.", valid=check_field(Assert.gt, 0.0), ) + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. 
Default: use if available.", + hint=FieldHint.expert, + ) @property def loss_class(self) -> "type[LanguageModelDistillationLoss]": diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 351aa210b..25e0c19b6 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -4,7 +4,7 @@ from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward -from fast_llm.functional.triton.cross_entropy import triton_cross_entropy_forward_backward +from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.config import ( LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, @@ -13,32 +13,9 @@ from fast_llm.utils import Assert -def _get_implementation( - default: EntropyLossImplementation = EntropyLossImplementation.auto, - loss_type: EntropyLossType = EntropyLossType.cross_entropy, - vocab_parallel: bool = False, -) -> EntropyLossImplementation: - # Vocab parallel requires fused. - if vocab_parallel: - assert default in (EntropyLossImplementation.auto, EntropyLossImplementation.fused) - return EntropyLossImplementation.fused - - # Triton only available for cross_entropy - if TritonConfig.TRITON_ENABLED and torch.cuda.is_available() and loss_type == EntropyLossType.cross_entropy: - return EntropyLossImplementation.triton if default == EntropyLossImplementation.auto else default - else: - assert default != EntropyLossImplementation.triton - - # Otherwise, use fused. 
- return EntropyLossImplementation.fused if default == EntropyLossImplementation.auto else default - - class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._implementation = _get_implementation( - self._config.implementation, self._config.loss_type, self._vocab_parallel - ) def forward_backward( self, @@ -52,10 +29,10 @@ def forward_backward( None, # Labels are already masked grad_output=self._get_grad_output(kwargs), group=self._parallel_dim.group if self._vocab_parallel else None, - implementation=self._implementation, logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.labels, entropy_loss_type=self._config.loss_type, + use_triton=self._config.use_triton, ) @@ -65,10 +42,6 @@ def __init__(self, *args, **kwargs): if self._prediction_distance > 0: raise NotImplementedError() - self._implementation = _get_implementation( - self._config.implementation, self._config.loss_type, self._vocab_parallel - ) - def forward_backward( self, logits: "torch.Tensor", @@ -81,17 +54,17 @@ def forward_backward( self._get_loss_mask(kwargs, split_index), grad_output=self._get_grad_output(kwargs), group=self._parallel_dim.group if self._vocab_parallel else None, - implementation=self._implementation, logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, + use_triton=self._config.use_triton, ) _ENTROPY_LOSS_IMPLEMENTATIONS = { EntropyLossImplementation.torch: torch_entropy_loss_forward_backward, EntropyLossImplementation.fused: fused_entropy_loss_forward_backward, - EntropyLossImplementation.triton: triton_cross_entropy_forward_backward, + EntropyLossImplementation.triton: triton_entropy_loss_forward_backward, } @@ -101,11 +74,11 @@ def entropy_loss_forward_backward( loss_mask: torch.Tensor | None, # (*batch,) grad_output: float | None, group: 
torch.distributed.ProcessGroup | None = None, - implementation: EntropyLossImplementation = EntropyLossImplementation.fused, logits_scale_factor: float = 1.0, temperature: float = 1.0, target_format: TargetFormat = TargetFormat.labels, entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, + use_triton: bool | None = None, **kwargs, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ @@ -114,6 +87,7 @@ def entropy_loss_forward_backward( It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, which is faster and has a relatively small memory overhead. """ + if target_format == TargetFormat.labels: Assert.eq(target.shape, logits.shape[:-1]) Assert.eq(target.dtype, torch.int64) @@ -123,7 +97,11 @@ def entropy_loss_forward_backward( assert target.dtype.is_floating_point, target.dtype if loss_mask is not None: Assert.eq(loss_mask.shape, logits.shape[:-1]) - return _ENTROPY_LOSS_IMPLEMENTATIONS[implementation]( + return ( + triton_entropy_loss_forward_backward + if TritonConfig.enabled(logits.device, use_triton) + else fused_entropy_loss_forward_backward + )( logits, target, loss_mask, diff --git a/tests/conftest.py b/tests/conftest.py index 4f7d7bad0..23fc58b16 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -279,3 +279,8 @@ def pytest_xdist_make_scheduler(config, log): # Always use grouped load balancing to handle dependencies, and make it work with `-n`. 
assert config.getvalue("dist") == "load" return xdist.scheduler.LoadGroupScheduling(config, log) + + +@pytest.fixture(scope="session") +def testing_device() -> torch.device: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") diff --git a/tests/functional/test_functional.py b/tests/functional/test_functional.py index 6471a516f..7980f05bf 100644 --- a/tests/functional/test_functional.py +++ b/tests/functional/test_functional.py @@ -2,6 +2,7 @@ import torch from fast_llm.functional.config import ActivationType, MLPRecomputeLevel +from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped, torch_mlp_activation from fast_llm.functional.triton.sparse_copy import get_sparse_map from fast_llm.utils import Assert @@ -19,12 +20,12 @@ def _get_target_log_probability_for_spans(log_probabilities: torch.Tensor, spans @pytest.mark.parametrize( "activation", [ActivationType.gelu, ActivationType.silu, ActivationType.relu, ActivationType.squared_relu] ) -def test_mlp_recomputation(gated, activation): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - tokens = 1024 - hidden_size = 2048 - intermediate_size = 4096 - std = 1 / 64 +def test_mlp_recomputation(gated, activation, testing_device): + device = torch.device(testing_device) + tokens = 64 + hidden_size = 128 + intermediate_size = 256 + std = 1 / 16 input_ = torch.randn(tokens, hidden_size, device=device, requires_grad=True) output_grad = torch.randn(tokens, hidden_size, device=device, requires_grad=True) weight_1 = torch.normal(0, std, (intermediate_size * (gated + 1), hidden_size), device=device, requires_grad=True) @@ -53,7 +54,20 @@ def test_mlp_recomputation(gated, activation): param.grad = None param.grad_buffer = torch.empty_like(param) param.param_grad_is_zero = True - output = mlp_autograd(input_, None, *params, gated, activation, None, False, True, recompute_level, True) + output = mlp_autograd( + input_, + 
None, + *params, + gated, + activation, + None, + False, + True, + recompute_level, + True, + None, + triton_available and torch.cuda.is_available(), + ) output.backward(output_grad) if i == 0: Assert.rms_close(output, output_ref, 1e-5) @@ -74,8 +88,8 @@ def test_mlp_recomputation(gated, activation): # Takes ~6s, much more if it needs to compile, reducing the hidden size doesn't help. @pytest.mark.slow @pytest.mark.skip("Dropless MoE is broken") -def test_dropless_mlp(): - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def test_dropless_mlp(testing_device): + device = torch.device(testing_device) num_experts = 4 experts_per_token = 4 tokens = 256 diff --git a/tests/functional/test_sparse_matmul.py b/tests/functional/test_sparse_matmul.py index 899dad967..0ebf9c5a5 100644 --- a/tests/functional/test_sparse_matmul.py +++ b/tests/functional/test_sparse_matmul.py @@ -12,7 +12,7 @@ output_sparse_matmul, ) from fast_llm.utils import Assert -from tests.utils.utils import requires_cuda +from tests.utils.utils import requires_triton @dataclasses.dataclass @@ -46,12 +46,11 @@ def sparse_dim_expanded(self) -> int: def num_experts(self) -> int: return len(self.expert_begins) - @functools.cached_property - def sparse_map(self) -> SparseMap: + def get_sparse_map(self, device: torch.device) -> SparseMap: return SparseMap( num_experts=self.num_experts, - expert_ends=torch.tensor(self.expert_ends, device="cuda"), - expert_pad_begins=torch.tensor(self.expert_pad_begins, device="cuda"), + expert_ends=torch.tensor(self.expert_ends, device=device), + expert_pad_begins=torch.tensor(self.expert_pad_begins, device=device), num_rows=self.expert_ends[-1], # Not needed sparse_rows=None, @@ -60,8 +59,8 @@ def sparse_map(self) -> SparseMap: num_experts_per_token=None, ) - def normal(self, dim_0: int, dim_1: int) -> torch.Tensor: - return torch.normal(0, self.std, (dim_0, dim_1), device="cuda") + def normal(self, dim_0: int, dim_1: int, device: torch.device) -> 
torch.Tensor: + return torch.normal(0, self.std, (dim_0, dim_1), device=device) _SPARSE_TEST_DATAS = ( @@ -80,28 +79,28 @@ def normal(self, dim_0: int, dim_1: int) -> torch.Tensor: ) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_dense_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) - rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim) +def test_dense_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim, testing_device) output = dense_matmul(lhs, rhs) output_ref = torch.matmul(lhs, rhs) Assert.rms_close(output, output_ref, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_output_sparse_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) - rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded) +def test_output_sparse_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.dense_dim, sparse_test_data.sparse_dim_expanded, testing_device) # Randomly initialize the output to ensure padded values have no effect. 
- out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim) - output = output_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map, out) + out = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) + output = output_sparse_matmul(lhs, rhs, sparse_test_data.get_sparse_map(testing_device), out) output_ref = torch.zeros_like(output) for i in range(sparse_test_data.num_experts): @@ -114,14 +113,14 @@ def test_output_sparse_matmul(sparse_test_data): Assert.rms_close(output, output_ref, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_input_inner_sparse_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim) - rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim) +def test_input_inner_sparse_matmul(sparse_test_data, testing_device): + lhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.sparse_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.sparse_dim_expanded, sparse_test_data.dense_dim, testing_device) - output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map) + output = input_inner_sparse_matmul(lhs, rhs, sparse_test_data.get_sparse_map(testing_device)) output_ref = torch.zeros_like(output) for i in range(sparse_test_data.num_experts): @@ -134,14 +133,14 @@ def test_input_inner_sparse_matmul(sparse_test_data): Assert.rms_close(output, output_ref, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.slow @pytest.mark.parametrize("sparse_test_data", _SPARSE_TEST_DATAS) -def test_input_row_sparse_matmul(sparse_test_data): - lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim) - rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim) +def test_input_row_sparse_matmul(sparse_test_data, 
testing_device): + lhs = sparse_test_data.normal(sparse_test_data.sparse_dim, sparse_test_data.token_dim, testing_device) + rhs = sparse_test_data.normal(sparse_test_data.token_dim, sparse_test_data.dense_dim, testing_device) - output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.sparse_map) + output = input_row_sparse_matmul(lhs, rhs, sparse_test_data.get_sparse_map(testing_device)) output_ref = torch.zeros_like(output) for i in range(sparse_test_data.num_experts): diff --git a/tests/functional/test_triton_kernels.py b/tests/functional/test_triton_kernels.py index 79817bb03..2886ab14e 100644 --- a/tests/functional/test_triton_kernels.py +++ b/tests/functional/test_triton_kernels.py @@ -2,7 +2,7 @@ import torch from fast_llm.engine.config_utils.tensor_dim import TensorDim -from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, ActivationType, TritonConfig +from fast_llm.functional.config import MAX_DROPLESS_BLOCK_SIZE_ROW, ActivationType from fast_llm.functional.triton.adam import triton_adam from fast_llm.functional.triton.mlp import ( torch_mlp_activation, @@ -25,71 +25,66 @@ rotary_embeddings_real, ) from fast_llm.utils import Assert, rms_diff -from tests.utils.utils import requires_cuda +from tests.utils.utils import requires_cuda, requires_triton -@requires_cuda -def test_triton_fill(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(425, 549, dtype=torch.bfloat16, device="cuda") - triton_fill(x, 32) +@requires_triton +def test_triton_fill(testing_device): + x = torch.randn(425, 549, dtype=torch.float16, device=testing_device) + triton_fill(x, 32, use_triton=True) assert x.min().item() == x.max().item() == 32 -@requires_cuda -def test_triton_copy(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(7563, dtype=torch.bfloat16, device="cuda") +@requires_triton +def test_triton_copy(testing_device): + x = torch.randn(7563, dtype=torch.float32, device=testing_device).to(torch.float16) x1 = x.clone() y = torch.zeros_like(x) 
Assert.all_different(x, y) - triton_copy(x, y) + triton_copy(x, y, use_triton=True) Assert.all_equal(x, y) Assert.all_equal(x, x1) -@requires_cuda -def test_triton_copy_cast(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(7563, dtype=torch.bfloat16, device="cuda") +@requires_triton +def test_triton_copy_cast(testing_device): + x = torch.randn(7563, dtype=torch.float32, device=testing_device).to(torch.float16) x1 = x.clone() y = torch.zeros_like(x, dtype=torch.float32) Assert.all_different(x.float(), y) - triton_copy(x, y) + triton_copy(x, y, use_triton=True) Assert.rms_close(x, y, 1e-4) Assert.all_equal(x, x1) -@requires_cuda -def test_triton_add(): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(8934, dtype=torch.float32, device="cuda") +@requires_triton +def test_triton_add(testing_device): + x = torch.randn(8934, dtype=torch.float32, device=testing_device) x1 = x.clone() y = torch.zeros_like(x) y1 = y.clone() Assert.all_different(x, y) - z = triton_add(x, y) + z = triton_add(x, y, use_triton=True) z1 = x1 + y1 Assert.rms_close(z, z1, 1e-5) Assert.all_equal(x, x1) Assert.all_equal(y, y1) -@requires_cuda +@requires_triton @pytest.mark.parametrize( ("batch_size", "sequence_length", "num_heads", "head_size"), - [(4, 1024, 8, 128), (1, 32, 1, 16), (2, 2048, 2, 192), (3, 519, 7, 134), (2, 100000, 2, 4)], + [(4, 32, 2, 16), (1, 32, 1, 16), (2, 64, 2, 96), (3, 59, 7, 22)], ) -def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): - assert TritonConfig.TRITON_ENABLED - x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.float32, device="cuda") +def test_triton_rotary(batch_size, sequence_length, num_heads, head_size, testing_device): + x = torch.randn(batch_size, sequence_length, num_heads, head_size, dtype=torch.float32, device=testing_device) frequencies = ( DefaultRotaryConfig() .get_layer(TensorDim("", head_size)) ._get_frequencies( sequence_length, head_size, - device="cuda", + device=testing_device, ) 
) @@ -110,15 +105,14 @@ def test_triton_rotary(batch_size, sequence_length, num_heads, head_size): Assert.rms_close(y_real, y_triton, 1e-4) -@requires_cuda +@requires_triton @pytest.mark.parametrize("has_bias", [True, False]) @pytest.mark.parametrize("zero_centered", [True, False]) -def test_triton_normalization(has_bias, zero_centered): - assert TritonConfig.TRITON_ENABLED - input_ = torch.randn(4096, 1024, device="cuda", requires_grad=True) +def test_triton_normalization(has_bias, zero_centered, testing_device): + input_ = torch.randn(32, 128, device=testing_device, requires_grad=True) output_grad = torch.randn_like(input_) - weight = torch.randn(1024, device="cuda", requires_grad=True) + weight = torch.randn(128, device=testing_device, requires_grad=True) weight.grad_buffer = torch.empty_like(weight) weight.param_grad_is_zero = True @@ -160,7 +154,7 @@ def test_triton_normalization(has_bias, zero_centered): Assert.rms_close(bias_grad0, bias.grad, 1e-3) -@requires_cuda +@requires_triton @pytest.mark.parametrize("gated", [True, False]) @pytest.mark.parametrize( "activation", @@ -173,10 +167,9 @@ def test_triton_normalization(has_bias, zero_centered): ], ) @pytest.mark.parametrize("recompute", [True, False]) -def test_triton_mlp_activation(gated, activation, recompute): - assert TritonConfig.TRITON_ENABLED - input_ = torch.randn(1024, 4096 * (2 if gated else 1), device="cuda", requires_grad=True) - output_grad = torch.randn(1024, 4096, device="cuda") +def test_triton_mlp_activation(gated, activation, recompute, testing_device): + input_ = torch.randn(32, 128 * (2 if gated else 1), device=testing_device, requires_grad=True) + output_grad = torch.randn(32, 128, device=testing_device) output1, context = triton_mlp_activation_forward(input_, gated, activation) input_grad1, output3 = triton_mlp_activation_backward(output_grad, context, recompute) @@ -190,10 +183,9 @@ def test_triton_mlp_activation(gated, activation, recompute): Assert.rms_close(output1, output3, 1e-5) 
-@requires_cuda -def test_triton_adam(): - assert TritonConfig.TRITON_ENABLED - params = torch.randn(4576427, dtype=torch.float32, device="cuda") +@requires_triton +def test_triton_adam(testing_device): + params = torch.randn(45764, dtype=torch.float32, device=testing_device) grads = torch.randn_like(params) exp_avgs = torch.randn_like(params) exp_avg_sqs = torch.randn_like(params).abs() @@ -248,13 +240,14 @@ def compare(i, j, fn, arg): compare(0, 4, Assert.eq, 0) +# TODO: Failing with triton interpreter @requires_cuda @pytest.mark.parametrize( ("num_rows_dense", "num_experts", "num_experts_per_token"), [(2048, 8, 2), (2048, 6, 2), (2048, 8, 8), (256, 8, 2), (5627, 8, 2)], ) -def test_triton_sparse_map(num_rows_dense, num_experts, num_experts_per_token): - logits = torch.randn((num_rows_dense, num_experts), device="cuda") +def test_triton_sparse_map(num_rows_dense, num_experts, num_experts_per_token, testing_device): + logits = torch.randn((num_rows_dense, num_experts), device=testing_device) _, top_experts = torch.topk(logits, num_experts_per_token, dim=-1) sparse_map_triton = get_sparse_map(top_experts, num_experts, use_triton=True) diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index e54197204..9ca78d64e 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -9,10 +9,11 @@ from fast_llm.engine.config_utils import data_type from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.distributed.config import DistributedBackend -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig +from fast_llm.functional.config import EntropyLossType, TargetFormat +from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available +from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward from 
fast_llm.layers.language_model.loss.dpo import dpo_loss -from fast_llm.layers.language_model.loss.entropy_loss import entropy_loss_forward_backward from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward from fast_llm.utils import Assert @@ -104,8 +105,10 @@ def reference_dpo_loss( return -torch.nn.functional.logsigmoid(beta * (pi_logratios - ref_logratios)).mean() -_BATCH_SHAPES = ((64,), (16, 8)) +# _BATCH_SHAPES = ((64,), (16, 8)) +_BATCH_SHAPES = ((1,),) _LOSS_PARAMETERS = ( + (8, 1.0, 1.0, False, DataType.float32, None), # Simple (500, 1.0, 1.0, False, DataType.float32, None), # Simple (256, 1.0, 1.0, False, DataType.float32, None), # Power of 2 (500, None, 1.0, False, DataType.float32, None), # No grad @@ -131,13 +134,13 @@ def _test_entropy_loss( group=None, ): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: - pytest.skip(reason="Not implemented") + pytest.skip(reason="Reverse KL loss not implemented for target labels") # TODO: Test tensor-parallel implementation. logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, target_format, batch_shape, dtype) local_logits = split_op(logits, group, -1).contiguous() local_target = target if target_format == TargetFormat.labels else split_op(target, group, -1).contiguous() # Torch serves as the reference implementation. 
- out_ref, grad_ref = entropy_loss_forward_backward( + out_ref, grad_ref = torch_entropy_loss_forward_backward( logits=logits, target=target, loss_mask=loss_mask, @@ -145,9 +148,8 @@ def _test_entropy_loss( logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.torch, ) - out_fused, grad_fused = entropy_loss_forward_backward( + out_fused, grad_fused = fused_entropy_loss_forward_backward( logits=local_logits, target=local_target, loss_mask=loss_mask, @@ -156,7 +158,6 @@ def _test_entropy_loss( logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.fused, ) _compare_losses_and_grads( out_fused, @@ -168,11 +169,9 @@ def _test_entropy_loss( group=group, ) - if entropy_loss_type == EntropyLossType.reverse_kl or not triton_available: - # Triton implementation only supports cross-entropy. + if not triton_available: return - assert TritonConfig.TRITON_ENABLED - out_triton, grad_triton = entropy_loss_forward_backward( + out_triton, grad_triton = triton_entropy_loss_forward_backward( logits=local_logits, target=local_target, loss_mask=loss_mask, @@ -180,11 +179,18 @@ def _test_entropy_loss( logits_scale_factor=logits_scale_factor, target_format=target_format, entropy_loss_type=entropy_loss_type, - implementation=EntropyLossImplementation.triton, group=group, block_size=block_size, ) - _compare_losses_and_grads(out_triton, out_ref, grad_output is not None, grad_triton, grad_ref, group=group) + _compare_losses_and_grads( + out_triton, + out_ref, + grad_output is not None, + grad_triton, + grad_ref, + threshold=1e-5 if target_format != TargetFormat.probabilities and data_type == DataType.float32 else 1e-4, + group=group, + ) def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): diff --git a/tests/layers/test_rotary.py 
b/tests/layers/test_rotary.py index 112c88a66..f34b9a35d 100644 --- a/tests/layers/test_rotary.py +++ b/tests/layers/test_rotary.py @@ -8,24 +8,27 @@ from fast_llm.utils import Assert -def test_rotary_2d(): +def test_rotary_2d(testing_device): """ Compare Fast-LLM's implementation of 2d rotary embeddings with Pixtral. """ head_dim = 16 num_heads = 8 - device = "cuda" if torch.cuda.is_available() else "cpu" patch_positions = torch.tensor( [[h, w] for h in range(4) for w in range(4)], dtype=torch.int64, - device=device, + device=testing_device, ) - query = torch.empty(2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device=device).normal_() + query = torch.empty( + 2, len(patch_positions), num_heads, head_dim, dtype=torch.float32, device=testing_device + ).normal_() key = torch.empty_like(query).normal_() pixtral_config = transformers.PixtralVisionConfig(hidden_size=head_dim * num_heads, num_attention_heads=num_heads) - pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to(device) + pixtral_rotary = transformers.models.pixtral.modeling_pixtral.PixtralRotaryEmbedding(pixtral_config).to( + testing_device + ) # Convert patch positions (h, w) to Pixtral's linear position IDs # Pixtral expects: position_id = h * max_patches_per_side + w position_ids = ( @@ -37,7 +40,7 @@ def test_rotary_2d(): ) fast_llm_rotary = Rotary2DConfig().get_layer(TensorDim("head_dim", head_dim)) - kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: device} + kwargs = {VisionKwargs.patch_positions: patch_positions, AttentionKwargs.device: testing_device} fast_llm_rotary.preprocess(kwargs) output_fast_llm_query, output_fast_llm_key = fast_llm_rotary.forward(query, key, kwargs) diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index 777214aae..c12fe52e9 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -85,7 +85,7 @@ def _compare_mixers( @pytest.mark.slow # Arguments 
('seq_idx',) not implemented for torch implementation of 1d convolution. @pytest.mark.skipif(not transformers.utils.import_utils.is_causal_conv1d_available(), reason="GDN deps missing") -def test_gdn(): +def test_gdn(testing_device): dtype = torch.bfloat16 NUM_V_HEADS = 4 @@ -103,7 +103,7 @@ def test_gdn(): hf_layer = ( Apriel2GatedDeltaNet(HIDDEN_SIZE, {**config_common, "norm_eps": 1e-5}, layer_idx=0, dtype=dtype) - .to(device="cuda" if torch.cuda.is_available() else "cpu", dtype=dtype) + .to(device=testing_device, dtype=dtype) .eval() ) fast_llm_config = GatedDeltaNetConfig.from_dict(config_common, {"normalization": {"epsilon": 1e-5}}) diff --git a/tests/models/test_checkpoint.py b/tests/models/test_checkpoint.py index 955fa534c..1da264739 100644 --- a/tests/models/test_checkpoint.py +++ b/tests/models/test_checkpoint.py @@ -231,14 +231,14 @@ def do_load_and_compare_checkpoints( @pytest.mark.depends_on(on=["test_conversion[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) def test_load_pretrained( - model_testing_config, run_test_script_base_path, get_convert_path, load_and_compare_checkpoints + model_testing_config, run_test_script_base_path, get_convert_path, load_and_compare_checkpoints, testing_device ): # Test that loadind a pretrained model from either converted checkpoint always yields the exact same model. 
reference_config = model_testing_config.model_config_class.from_dict( yaml.safe_load(get_convert_path().parents[1].joinpath("config.yaml").open("r"))["model"] ) reference_shard = safetensors.torch.load_file( - get_convert_path() / "rank_0.safetensors", device="cuda" if torch.cuda.is_available() else "cpu" + get_convert_path() / "rank_0.safetensors", device=str(testing_device) )[_WEIGHT_SHARD_SAVE_NAME] load_and_compare_checkpoints( FastLLMCheckpointFormat, @@ -304,8 +304,7 @@ def test_load_pretrained( @pytest.mark.depends_on(on=["test_load_pretrained[{model_testing_config}]"]) @pytest.mark.model_testing_group(ModelTestingGroup.convert) -def test_huggingface_model(model_testing_config, get_convert_path): - device = "cuda" if torch.cuda.is_available() else "cpu" +def test_huggingface_model(model_testing_config, get_convert_path, testing_device): distributed_update = {("distributed", "use_cuda"): torch.cuda.is_available()} if model_testing_config.checkpoint_format is None: return @@ -331,11 +330,11 @@ def test_huggingface_model(model_testing_config, get_convert_path): 384, size=(4, 100), dtype=torch.int64, - device=device, + device=testing_device, ) kwargs = {} if model_testing_config.model_type == "multimodal": - kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(device) + kwargs["pixel_values"] = torch.rand([6, 3, 20, 20]).to(testing_device) kwargs["image_sizes"] = torch.tensor( [ [20, 20], # Full image, 25 patches @@ -373,7 +372,7 @@ def test_huggingface_model(model_testing_config, get_convert_path): errors = [] model_as_hf = ( model_testing_config.auto_model_class.from_pretrained(hf_path, trust_remote_code=True) - .to("cuda" if torch.cuda.is_available() else "cpu") + .to(testing_device) .eval() ) for name, model in zip( diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index ca92f0b74..8c131dfa7 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -15,7 +15,7 @@ from fast_llm.engine.distributed.config import PhaseType from 
fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.models.gpt.config import GPTBatchConfig, GPTModelConfig -from tests.utils.utils import get_base_model, requires_cuda +from tests.utils.utils import get_base_model def create_test_batch( @@ -46,7 +46,7 @@ def get_minimal_model(): "embeddings": {"vocab_size": 1000}, "hidden_size": 64, }, - "distributed": {}, + "distributed": {"use_cuda": torch.cuda.is_available()}, }, ) model, distributed = get_base_model(config) @@ -82,7 +82,6 @@ def run_preprocess_batch(model, distributed_config, batch: LanguageModelBatch, p ) -@requires_cuda class TestLossMaskIntegration: """ Integration tests for loss_mask computation in preprocess_batch. diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 5a6aff831..a4f28d14c 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -211,6 +211,8 @@ def update_and_add_testing_config( "save": True, "show": False, }, + # Triton kernels are extremely slow in interpreter mode. 
+ "enable_triton_kernels": torch.cuda.is_available(), # Uncomment to enable model debug logging: # "model_debug_level": _LOG_LEVEL, }, diff --git a/tests/utils/utils.py b/tests/utils/utils.py index f0ca20db8..da293e1df 100644 --- a/tests/utils/utils.py +++ b/tests/utils/utils.py @@ -9,11 +9,13 @@ from fast_llm.engine.distributed.distributed import Distributed from fast_llm.engine.multi_stage.config import FastLLMModelConfig, StageConfig from fast_llm.engine.multi_stage.stage import Stage +from fast_llm.functional.triton import triton_available from tests.utils.global_variables import TEST_RESULTS_PATH logger = logging.getLogger(__name__) requires_cuda = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is not available") +requires_triton = pytest.mark.skipif(not triton_available, reason="Triton is not available") @pytest.fixture(scope="session") From d99511b2ee83a7d41c0c68ebead9e08e6ddb5343 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 04:21:09 -0500 Subject: [PATCH 13/22] rename --- .../functional/triton/{cross_entropy.py => entropy_loss.py} | 0 fast_llm/layers/language_model/loss/entropy_loss.py | 2 +- tests/layers/test_lm_losses.py | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) rename fast_llm/functional/triton/{cross_entropy.py => entropy_loss.py} (100%) diff --git a/fast_llm/functional/triton/cross_entropy.py b/fast_llm/functional/triton/entropy_loss.py similarity index 100% rename from fast_llm/functional/triton/cross_entropy.py rename to fast_llm/functional/triton/entropy_loss.py diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 25e0c19b6..f81e4e4ba 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -4,7 +4,7 @@ from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig from fast_llm.functional.entropy_loss import 
fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward -from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.config import ( LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 9ca78d64e..2e8786919 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -12,7 +12,7 @@ from fast_llm.functional.config import EntropyLossType, TargetFormat from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available -from fast_llm.functional.triton.cross_entropy import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.loss import loss_forward_backward from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward From 094ac85615c81c30305bf142a5fed05ea4a43c41 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 05:12:54 -0500 Subject: [PATCH 14/22] Z loss --- fast_llm/functional/triton/__init__.py | 1 + fast_llm/functional/triton/entropy_loss.py | 9 +- fast_llm/functional/triton/z_loss.py | 138 ++++++++++++++++++ fast_llm/layers/language_model/loss/z_loss.py | 8 +- tests/layers/test_lm_losses.py | 31 +++- 5 files changed, 176 insertions(+), 11 deletions(-) create mode 100644 fast_llm/functional/triton/z_loss.py diff --git a/fast_llm/functional/triton/__init__.py b/fast_llm/functional/triton/__init__.py index 61ead1c60..f5b394bfb 100644 --- a/fast_llm/functional/triton/__init__.py +++ b/fast_llm/functional/triton/__init__.py @@ -38,3 +38,4 @@ def 
tl_full(shape, value, dtype): else: tl_arange = tl.arange + tl_full = tl.full diff --git a/fast_llm/functional/triton/entropy_loss.py b/fast_llm/functional/triton/entropy_loss.py index 335048770..ad826f3e6 100644 --- a/fast_llm/functional/triton/entropy_loss.py +++ b/fast_llm/functional/triton/entropy_loss.py @@ -636,7 +636,7 @@ def _rescale_sum_exp_logits( return sum_exp_logits * (local_max_logits - max_logits).exp() -def _parallel_sum_exp_logits( +def parallel_sum_exp_logits( sum_exp_logits: torch.Tensor, local_max_logits: torch.Tensor, group: torch.distributed.ProcessGroup | None, @@ -758,7 +758,7 @@ def triton_entropy_loss_forward_backward( col_min=n_cols * group.rank(), **kwargs, ) - max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) torch.distributed.all_reduce(partial_losses, op=torch.distributed.ReduceOp.SUM, group=group) loss = _cross_entropy_loss_from_labels(partial_losses, target, sum_exp_logits, max_logits) if grad_output is not None: @@ -827,11 +827,10 @@ def triton_entropy_loss_forward_backward( target_logits_scale_factor=logits_scale_factor / temperature, from_logits=target_format == TargetFormat.logits, **kwargs, - **backward_kwargs, ) - max_logits, sum_exp_logits = _parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) if target_format == TargetFormat.logits: - target_max_logits, target_sum_exp_logits = _parallel_sum_exp_logits( + target_max_logits, target_sum_exp_logits = parallel_sum_exp_logits( target_sum_exp_logits, local_target_max_logits, group ) if entropy_loss_type != EntropyLossType.reverse_kl: diff --git a/fast_llm/functional/triton/z_loss.py b/fast_llm/functional/triton/z_loss.py new file mode 100644 index 000000000..298c3c2a7 --- /dev/null +++ b/fast_llm/functional/triton/z_loss.py @@ -0,0 
+1,138 @@ +import torch + +from fast_llm.functional.triton import tl, tl_arange, tl_constexpr, triton, triton_jit +from fast_llm.functional.triton.entropy_loss import ( + parallel_sum_exp_logits, + triton_cross_entropy_forward_from_labels_parallel_kernel, + triton_fused_softmax_base, +) + + +@triton_jit() +def triton_z_loss_forward_backward_kernel( + logits_ptr, + loss_mask_ptr, + n_cols: tl_constexpr, + logits_stride_0: tl_constexpr, + block_size: tl_constexpr, + losses_ptr=None, + max_logits_ptr=None, + sum_exp_logits_ptr=None, + grad_losses=None, + grad_logits_ptr=None, + grad_logits_stride_0: tl_constexpr = None, + logits_scale_factor: tl_constexpr = 1.0, +): + # TODO: Int64 ptr only if needed? + block_idx = tl.program_id(0).to(tl.int64) + logits_ptr = logits_ptr + block_idx * logits_stride_0 + + if loss_mask_ptr is not None and tl.load(loss_mask_ptr + block_idx) == 0: + # This entry is masked, ignore. + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + if max_logits_ptr is None or sum_exp_logits_ptr is None: + exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( + logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor + ) + else: + max_logits = tl.load(max_logits_ptr + block_idx) + sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) + + log_sum_exp_logits = tl.log(sum_exp_logits) + max_logits + + if losses_ptr is not None: + tl.store(losses_ptr + block_idx, log_sum_exp_logits * log_sum_exp_logits) + + if grad_losses is not None: + if logits_scale_factor != 1.0: + grad_losses *= logits_scale_factor + grad_losses *= 2 * log_sum_exp_logits / sum_exp_logits + # Run in reverse order to maximize 
input and cache reuse. + col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size + for col_offset in tl.static_range(col_offset_start, -1, -block_size): + if max_logits_ptr is not None or sum_exp_logits_ptr is not None or col_offset != col_offset_start: + col_offsets = tl_arange(col_offset, col_offset + block_size) + mask = col_offsets < n_cols + logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) + if logits_scale_factor != 1.0: + logits *= logits_scale_factor + exp_logits = tl.exp(logits - max_logits) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, exp_logits * grad_losses, mask=mask + ) + + +def triton_z_loss_forward_backward( + logits: torch.Tensor, + loss_mask: torch.Tensor | None, + grad_output: float | None, + group: torch.distributed.ProcessGroup | None = None, + logits_scale_factor: float = 1.0, + block_size: int | None = None, + num_warps: int | None = None, +) -> tuple[torch.Tensor, torch.Tensor]: + assert logits.is_contiguous() + if loss_mask is not None: + assert loss_mask.is_contiguous() + n_rows = logits.shape[:-1].numel() + n_cols = logits.size(-1) + if block_size is None: + block_size = min(triton.next_power_of_2(n_cols), 32768) + if num_warps is None: + num_warps = 4 if block_size < 2048 else (8 if block_size < 8192 else 16) + kwargs = { + "logits_stride_0": logits.stride(-2), + "n_cols": n_cols, + "logits_scale_factor": logits_scale_factor, + "block_size": block_size, + "num_warps": num_warps, + } + grad_logits = None if grad_output is None else torch.empty_like(logits) + backward_kwargs = ( + {} + if grad_output is None + else { + "grad_logits_ptr": grad_logits, + "grad_losses": grad_output / n_rows, + "grad_logits_stride_0": grad_logits.stride(-2), + } + ) + losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) + if group is None: + triton_z_loss_forward_backward_kernel[(n_rows,)]( + logits, + loss_mask_ptr=loss_mask, + losses_ptr=losses, + 
**kwargs, + **backward_kwargs, + ) + else: + local_max_logits = torch.empty(n_rows, dtype=torch.float, device=logits.device) + sum_exp_logits = torch.empty_like(local_max_logits) + triton_cross_entropy_forward_from_labels_parallel_kernel[(n_rows,)]( + logits, + None, + max_logits_ptr=local_max_logits, + sum_exp_logits_ptr=sum_exp_logits, + **kwargs, + ) + max_logits, sum_exp_logits = parallel_sum_exp_logits(sum_exp_logits, local_max_logits, group) + triton_z_loss_forward_backward_kernel[(n_rows,)]( + logits, + loss_mask_ptr=loss_mask, + losses_ptr=losses, + max_logits_ptr=max_logits, + sum_exp_logits_ptr=sum_exp_logits, + **kwargs, + **backward_kwargs, + ) + return losses.mean(), grad_logits diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 82b8d5318..0e675d338 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -2,7 +2,9 @@ import torch +from fast_llm.functional.config import TritonConfig from fast_llm.functional.entropy_loss import fused_softmax_base +from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.config import LanguageModelZLossConfig from fast_llm.layers.language_model.loss.loss import LanguageModelLoss @@ -20,7 +22,9 @@ def forward_backward( kwargs: dict[str, typing.Any], split_index: int = 0, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return z_loss_forward_backward( + return ( + triton_z_loss_forward_backward if TritonConfig.enabled(logits.device) else fused_z_loss_forward_backward + )( logits, self._get_loss_mask(kwargs, split_index), grad_output=self._get_grad_output(kwargs), @@ -44,7 +48,7 @@ def z_loss( @torch.compile -def z_loss_forward_backward( +def fused_z_loss_forward_backward( logits: torch.Tensor, loss_mask: torch.Tensor | None, grad_output: float | None, diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 
2e8786919..2f04a38ee 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -13,9 +13,10 @@ from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward from fast_llm.functional.triton import triton_available from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward +from fast_llm.functional.triton.z_loss import triton_z_loss_forward_backward from fast_llm.layers.language_model.loss.dpo import dpo_loss from fast_llm.layers.language_model.loss.loss import loss_forward_backward -from fast_llm.layers.language_model.loss.z_loss import z_loss, z_loss_forward_backward +from fast_llm.layers.language_model.loss.z_loss import fused_z_loss_forward_backward, z_loss from fast_llm.utils import Assert from tests.utils.dataset import get_random_spans from tests.utils.subtest import DistributedTestContext @@ -193,7 +194,9 @@ def _test_entropy_loss( ) -def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, group=None): +def _test_z_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, group=None +): logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.logits, batch_shape, dtype) out_ref, grad_ref = loss_forward_backward( grad_output, @@ -202,7 +205,7 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los loss_mask, logits_scale_factor=logits_scale_factor, ) - out_fused, grad_fused = z_loss_forward_backward( + out_fused, grad_fused = fused_z_loss_forward_backward( split_op(logits, group, -1), loss_mask, grad_output, @@ -218,6 +221,25 @@ def _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, los threshold=1e-5 if data_type == DataType.float32 else 1e-4, group=group, ) + if not triton_available: + return + out_triton, grad_triton = triton_z_loss_forward_backward( + split_op(logits, group, -1), + 
loss_mask, + grad_output, + group=group, + logits_scale_factor=logits_scale_factor, + block_size=block_size, + ) + _compare_losses_and_grads( + out_triton, + out_ref, + grad_output is not None, + grad_triton, + grad_ref, + threshold=1e-5 if data_type == DataType.float32 else 1e-4, + group=group, + ) @pytest.mark.slow @@ -257,7 +279,7 @@ def test_entropy_loss( ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS ) def test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size): - _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype) + _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size) @pytest.mark.skip(reason="DPO loss is broken") @@ -309,6 +331,7 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa logits_scale_factor, loss_masking, dtype, + block_size, test_context.group, ) From 35fd220a7661746884db1cc2a2f0655d9c8788b1 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Fri, 6 Feb 2026 05:14:16 -0500 Subject: [PATCH 15/22] fix --- fast_llm/layers/language_model/loss/config.py | 6 ++++++ fast_llm/layers/language_model/loss/z_loss.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index d636d6af7..970003122 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -160,6 +160,12 @@ class LanguageModelZLossConfig(LanguageModelLossConfig): _abstract: typing.ClassVar[bool] = False + use_triton: bool | None = Field( + default=None, + desc="Enable triton implementation. 
Default: use if available.", + hint=FieldHint.expert, + ) + @property def loss_class(self) -> "type[LanguageModelZLoss]": from fast_llm.layers.language_model.loss.z_loss import LanguageModelZLoss diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 0e675d338..1df54f7a5 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -23,7 +23,9 @@ def forward_backward( split_index: int = 0, ) -> "tuple[torch.Tensor, torch.Tensor | None]": return ( - triton_z_loss_forward_backward if TritonConfig.enabled(logits.device) else fused_z_loss_forward_backward + triton_z_loss_forward_backward + if TritonConfig.enabled(logits.device, self._config.use_triton) + else fused_z_loss_forward_backward )( logits, self._get_loss_mask(kwargs, split_index), From 9133fcd0967a882871675c22570ac0b464eebfa4 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 7 Feb 2026 07:11:37 -0500 Subject: [PATCH 16/22] Grad accumulation --- fast_llm/functional/entropy_loss.py | 17 ++-- fast_llm/functional/triton/entropy_loss.py | 97 ++++++++++--------- fast_llm/functional/triton/z_loss.py | 28 +++--- fast_llm/layers/language_model/head.py | 6 +- fast_llm/layers/language_model/loss/config.py | 15 +++ fast_llm/layers/language_model/loss/dpo.py | 10 +- .../language_model/loss/entropy_loss.py | 77 +++------------ fast_llm/layers/language_model/loss/loss.py | 1 + fast_llm/layers/language_model/loss/z_loss.py | 15 ++- tests/layers/test_lm_losses.py | 86 +++++++++++----- 10 files changed, 190 insertions(+), 162 deletions(-) diff --git a/fast_llm/functional/entropy_loss.py b/fast_llm/functional/entropy_loss.py index 4d39b3a77..65dcee32b 100644 --- a/fast_llm/functional/entropy_loss.py +++ b/fast_llm/functional/entropy_loss.py @@ -278,12 +278,13 @@ def fused_entropy_loss_forward_backward( logits: torch.Tensor, # (*batch, vocab) target: torch.Tensor, # (*batch,) or (*batch, vocab) loss_mask: 
torch.Tensor | None, # (*batch,) - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - entropy_loss_type: EntropyLossType, - group: ProcessGroup | None = None, + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, + group: torch.distributed.ProcessGroup | None = None, + logits_scale_factor: float = 1.0, temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ A fused implementation of cross-entropy with torch compile. @@ -335,5 +336,9 @@ def fused_entropy_loss_forward_backward( if loss_mask is not None: grad = grad * loss_mask.unsqueeze(-1) grad = grad.to(logits.dtype) + if grad_logits is None: + grad_logits = grad + else: + grad_logits.add_(grad) - return loss, grad + return loss, grad_logits diff --git a/fast_llm/functional/triton/entropy_loss.py b/fast_llm/functional/triton/entropy_loss.py index ad826f3e6..3d9937439 100644 --- a/fast_llm/functional/triton/entropy_loss.py +++ b/fast_llm/functional/triton/entropy_loss.py @@ -111,11 +111,27 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( grad_logits_stride_0: tl_constexpr = None, col_min: tl_constexpr = 0, logits_scale_factor: tl_constexpr = 1.0, + accumulate: tl_constexpr = False, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) logits_ptr = logits_ptr + block_idx * logits_stride_0 + label_idx = tl.load(labels_ptr + block_idx) + if label_idx < 0: + # This entry is masked, ignore. 
+ if losses_ptr is not None: + tl.store(losses_ptr + block_idx, 0) + if grad_losses is not None and not accumulate: + for col_offset in tl.static_range(0, n_cols, block_size): + col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) + tl.store( + grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, 0, mask=col_offsets < n_cols + ) + return + + label_idx -= col_min + if max_logits_ptr is None or sum_exp_logits_ptr is None: exp_logits, sum_exp_logits, max_logits, col_offsets, mask = triton_fused_softmax_base( logits_ptr, n_cols=n_cols, block_size=block_size, logits_scale_factor=logits_scale_factor @@ -124,8 +140,6 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( max_logits = tl.load(max_logits_ptr + block_idx) sum_exp_logits = tl.load(sum_exp_logits_ptr + block_idx) - label_idx = tl.load(labels_ptr + block_idx) - col_min - if losses_ptr is not None: if label_idx < 0 or label_idx >= n_cols: # Loss mask @@ -138,9 +152,7 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( tl.store(losses_ptr + block_idx, loss) if grad_losses is not None: - if label_idx < -col_min: - grad_losses = 0.0 - elif logits_scale_factor != 1.0: + if logits_scale_factor != 1.0: grad_losses *= logits_scale_factor # Run in reverse order to maximize input and cache reuse. 
col_offset_start: tl.constexpr = (n_cols - 1) // block_size * block_size @@ -158,9 +170,11 @@ def triton_cross_entropy_forward_backward_from_labels_kernel( grad_logits = grad_base else: grad_logits = tl.where(col_offsets == label_idx, grad_base - 1.0, grad_base) - tl.store( - grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits * grad_losses, mask=mask - ) + grad_logits *= grad_losses + grad_logits_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad_logits += tl.load(grad_logits_col_ptr, mask=mask) + tl.store(grad_logits_col_ptr, grad_logits, mask=mask) @triton_jit() @@ -315,6 +329,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, return_kl_loss: tl.constexpr = False, + accumulate: tl_constexpr = False, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) @@ -325,7 +340,7 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( # This entry is masked, ignore. 
if losses_ptr is not None: tl.store(losses_ptr + block_idx, 0) - if grad_losses is not None: + if grad_losses is not None and not accumulate: for col_offset in tl.static_range(0, n_cols, block_size): col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) tl.store( @@ -391,7 +406,10 @@ def triton_cross_entropy_from_distribution_forward_backward_kernel( target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) grad_logits = grad_losses * (exp_logits / sum_exp_logits - target) - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + grad_logits_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad_logits += tl.load(grad_logits_col_ptr, mask=mask) + tl.store(grad_logits_col_ptr, grad_logits, mask=mask) @triton_jit() @@ -424,9 +442,6 @@ def triton_reverse_kl_forward_from_distribution( sum_exp_logits=sum_exp_logits, logits_scale_factor=logits_scale_factor, ) - - # print("sum_exp_logits", sum_exp_logits) - # print("max_logits", new_max_logits) if from_logits: # log_target excludes the log_sum_exp term to be added later log_target, _, target_sum_exp_logits, target_new_max_logits, _, _ = triton_fused_softmax_iter_base( @@ -441,20 +456,11 @@ def triton_reverse_kl_forward_from_distribution( mask=mask, ) target = log_target - # print("target_sum_exp_logits", target_sum_exp_logits) - # print("new_max_logits", target_new_max_logits) else: target = tl.load(target_ptr + col_offsets, mask=mask, other=-float("inf")).to(tl.float32) log_target = tl.log(target) if col_offset == 0: - # predicted_log_probability=logits - new_max_logits - tl.log(sum_exp_logits) - # target_log_probability=log_target-target_new_max_logits-tl.log(target_sum_exp_logits) - # print("predicted_log_probability", predicted_log_probability) - # print("target_log_probability", target_log_probability) - # print("IUWH", exp_logits * 
(predicted_log_probability-target_log_probability)/sum_exp_logits) loss = tl.sum(tl.where(mask, exp_logits * (logits - log_target), 0)) - # print("max_logits", new_max_logits) - # print("partial_losses", exp_logits * (logits-log_target)) else: loss = loss * tl.exp(max_logits - new_max_logits) + tl.sum( @@ -464,7 +470,6 @@ def triton_reverse_kl_forward_from_distribution( if from_logits: target_max_logits = target_new_max_logits - # print("partial_loss", loss) if not return_partial_loss: loss = loss / sum_exp_logits - tl.log(sum_exp_logits) - max_logits if from_logits: @@ -547,6 +552,7 @@ def triton_reverse_kl_forward_backward_kernel_from_distribution( grad_logits_stride_0: tl_constexpr = None, logits_scale_factor: tl_constexpr = 1.0, target_logits_scale_factor: tl_constexpr = 1.0, + accumulate: tl_constexpr = False, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) @@ -557,7 +563,7 @@ def triton_reverse_kl_forward_backward_kernel_from_distribution( # This entry is masked, ignore. 
if losses_ptr is not None: tl.store(losses_ptr + block_idx, 0) - if grad_losses is not None: + if grad_losses is not None and not accumulate: for col_offset in tl.static_range(0, n_cols, block_size): col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) tl.store( @@ -584,17 +590,9 @@ def triton_reverse_kl_forward_backward_kernel_from_distribution( target_max_logits = tl.load(target_max_logits_ptr + block_idx) target_sum_exp_logits = tl.load(target_sum_exp_logits_ptr + block_idx) - # print("sum_exp_logits", sum_exp_logits) - # print("max_logits", max_logits) - - # if from_logits: - # print("target_sum_exp_logits", target_sum_exp_logits) - # print("target_max_logits", target_max_logits) - if losses_ptr is not None: if partial_losses_ptr is not None: loss = tl.load(partial_losses_ptr + block_idx) - # print("partial_loss", loss) loss = loss / sum_exp_logits - tl.log(sum_exp_logits) - max_logits if from_logits: loss = loss + tl.log(target_sum_exp_logits) + target_max_logits @@ -624,7 +622,10 @@ def triton_reverse_kl_forward_backward_kernel_from_distribution( grad_logits = ( grad_losses * (predicted_log_probability - target_log_probability - loss) * predicted_probability ) - tl.store(grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, grad_logits, mask=mask) + grad_logits_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad_logits += tl.load(grad_logits_col_ptr, mask=mask) + tl.store(grad_logits_col_ptr, grad_logits, mask=mask) @torch.compile @@ -687,15 +688,16 @@ def _cross_entropy_loss_from_distribution( def triton_entropy_loss_forward_backward( - logits: torch.Tensor, - target: torch.Tensor, - loss_mask: torch.Tensor | None, - grad_output: float | None, - logits_scale_factor: float, - target_format: TargetFormat, - entropy_loss_type: EntropyLossType, + logits: torch.Tensor, # (*batch, vocab) + target: torch.Tensor, # (*batch,) or (*batch, vocab) + loss_mask: torch.Tensor | None, # (*batch,) + 
grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, group: torch.distributed.ProcessGroup | None = None, + logits_scale_factor: float = 1.0, temperature: float = 1.0, + target_format: TargetFormat = TargetFormat.labels, + entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, block_size: int | None = None, num_warps: int | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: @@ -721,18 +723,17 @@ def triton_entropy_loss_forward_backward( "block_size": block_size, "num_warps": num_warps, } - - # TODO: Safe to do inplace? - grad_logits = None if grad_output is None else torch.empty_like(logits) - backward_kwargs = ( - {} - if grad_output is None - else { + if grad_output is None: + backward_kwargs = {} + else: + accumulate = grad_logits is not None + grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits + backward_kwargs = { "grad_logits_ptr": grad_logits, "grad_losses": grad_output / n_rows, "grad_logits_stride_0": grad_logits.stride(-2), + "accumulate": accumulate, } - ) if target_format == TargetFormat.labels: assert entropy_loss_type != EntropyLossType.reverse_kl if group is None: diff --git a/fast_llm/functional/triton/z_loss.py b/fast_llm/functional/triton/z_loss.py index 298c3c2a7..cb3220131 100644 --- a/fast_llm/functional/triton/z_loss.py +++ b/fast_llm/functional/triton/z_loss.py @@ -22,6 +22,7 @@ def triton_z_loss_forward_backward_kernel( grad_logits_ptr=None, grad_logits_stride_0: tl_constexpr = None, logits_scale_factor: tl_constexpr = 1.0, + accumulate: tl_constexpr = False, ): # TODO: Int64 ptr only if needed? block_idx = tl.program_id(0).to(tl.int64) @@ -31,7 +32,7 @@ def triton_z_loss_forward_backward_kernel( # This entry is masked, ignore. 
if losses_ptr is not None: tl.store(losses_ptr + block_idx, 0) - if grad_losses is not None: + if grad_losses is not None and not accumulate: for col_offset in tl.static_range(0, n_cols, block_size): col_offsets = tl_arange(int(col_offset), int(col_offset + block_size)) tl.store( @@ -66,15 +67,18 @@ def triton_z_loss_forward_backward_kernel( if logits_scale_factor != 1.0: logits *= logits_scale_factor exp_logits = tl.exp(logits - max_logits) - tl.store( - grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets, exp_logits * grad_losses, mask=mask - ) + grad_logits = exp_logits * grad_losses + grad_logits_col_ptr = grad_logits_ptr + block_idx * grad_logits_stride_0 + col_offsets + if accumulate: + grad_logits += tl.load(grad_logits_col_ptr, mask=mask) + tl.store(grad_logits_col_ptr, grad_logits, mask=mask) def triton_z_loss_forward_backward( logits: torch.Tensor, loss_mask: torch.Tensor | None, - grad_output: float | None, + grad_logits: torch.Tensor | None = None, + grad_output: float | None = None, group: torch.distributed.ProcessGroup | None = None, logits_scale_factor: float = 1.0, block_size: int | None = None, @@ -96,16 +100,18 @@ def triton_z_loss_forward_backward( "block_size": block_size, "num_warps": num_warps, } - grad_logits = None if grad_output is None else torch.empty_like(logits) - backward_kwargs = ( - {} - if grad_output is None - else { + if grad_output is None: + backward_kwargs = {} + else: + accumulate = grad_logits is not None + grad_logits = torch.empty_like(logits) if grad_logits is None else grad_logits + + backward_kwargs = { "grad_logits_ptr": grad_logits, "grad_losses": grad_output / n_rows, "grad_logits_stride_0": grad_logits.stride(-2), + "accumulate": accumulate, } - ) losses = torch.empty(n_rows, dtype=torch.float, device=logits.device) if group is None: triton_z_loss_forward_backward_kernel[(n_rows,)]( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index d9220d3e1..144074ca5 
100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -284,15 +284,13 @@ def _logits_loss_forward_backward_partial( losses, grad = {}, None for loss in self._losses: # losses are returned unscaled but the grads are already scaled - loss_value, grad_ = loss.forward_backward( + loss_value, grad = loss.forward_backward( logits, kwargs, split_index, + grad, ) losses[loss.name] = loss_value.detach() - if grad_ is not None: - # TODO: Accumulate grads in-place to reduce memory and compute overhead. - grad = grad_ if grad is None else grad + grad_ return losses, output_parallel_linear_backward(grad, context) if self.training else None diff --git a/fast_llm/layers/language_model/loss/config.py b/fast_llm/layers/language_model/loss/config.py index 970003122..b6a2ef175 100644 --- a/fast_llm/layers/language_model/loss/config.py +++ b/fast_llm/layers/language_model/loss/config.py @@ -1,4 +1,5 @@ import typing +import warnings from fast_llm.config import Config, Field, FieldHint, check_field, config_class from fast_llm.engine.distributed.config import DistributedConfig @@ -83,6 +84,13 @@ class LanguageModelLabelEntropyLossConfig(LanguageModelLossConfig): hint=FieldHint.expert, ) + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if "implementation" in default: + warnings.warn("`implementation` field is no longer supported for loss type `label`.") + del default["implementation"] + return super()._from_dict(default, strict) + @property def loss_class(self) -> "type[LanguageModelLabelEntropyLoss]": from fast_llm.layers.language_model.loss.entropy_loss import LanguageModelLabelEntropyLoss @@ -116,6 +124,13 @@ class LanguageModelDistillationLossConfig(LanguageModelLossConfig): hint=FieldHint.expert, ) + @classmethod + def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: + if "implementation" in default: + warnings.warn("`implementation` field is 
no longer supported for loss type `distillation`.") + del default["implementation"] + return super()._from_dict(default, strict) + @property def loss_class(self) -> "type[LanguageModelDistillationLoss]": from fast_llm.layers.language_model.loss.entropy_loss import LanguageModelDistillationLoss diff --git a/fast_llm/layers/language_model/loss/dpo.py b/fast_llm/layers/language_model/loss/dpo.py index 2194c6f86..177a681a4 100644 --- a/fast_llm/layers/language_model/loss/dpo.py +++ b/fast_llm/layers/language_model/loss/dpo.py @@ -23,12 +23,13 @@ def forward_backward( logits: "torch.Tensor", kwargs: dict[str, typing.Any], split_index: int = 0, + grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": if self._get_loss_mask(kwargs, split_index) is not None: raise NotImplementedError() - return loss_forward_backward( + loss, grad = loss_forward_backward( self._get_grad_output(kwargs), dpo_loss, logits, @@ -39,6 +40,13 @@ def forward_backward( self._config.beta, ) + if grad is not None: + if grad_logits is None: + grad_logits = grad + else: + grad_logits.add_(grad) + return loss, grad_logits + def dpo_loss( logits: torch.Tensor, diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index f81e4e4ba..537c7996d 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -2,15 +2,14 @@ import torch -from fast_llm.functional.config import EntropyLossImplementation, EntropyLossType, TargetFormat, TritonConfig -from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward, torch_entropy_loss_forward_backward +from fast_llm.functional.config import TargetFormat, TritonConfig +from fast_llm.functional.entropy_loss import fused_entropy_loss_forward_backward from fast_llm.functional.triton.entropy_loss import triton_entropy_loss_forward_backward from fast_llm.layers.language_model.loss.config import ( 
LanguageModelDistillationLossConfig, LanguageModelLabelEntropyLossConfig, ) from fast_llm.layers.language_model.loss.loss import LanguageModelLoss -from fast_llm.utils import Assert class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): @@ -22,17 +21,22 @@ def forward_backward( logits: "torch.Tensor", kwargs: dict[str, typing.Any], split_index: int = 0, + grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return entropy_loss_forward_backward( + return ( + triton_entropy_loss_forward_backward + if TritonConfig.enabled(logits.device, self._config.use_triton) + else fused_entropy_loss_forward_backward + )( logits, self._get_labels(kwargs, split_index), None, # Labels are already masked + grad_logits=grad_logits, grad_output=self._get_grad_output(kwargs), group=self._parallel_dim.group if self._vocab_parallel else None, logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.labels, entropy_loss_type=self._config.loss_type, - use_triton=self._config.use_triton, ) @@ -47,69 +51,20 @@ def forward_backward( logits: "torch.Tensor", kwargs: dict[str, typing.Any], split_index: int = 0, + grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - return entropy_loss_forward_backward( + return ( + triton_entropy_loss_forward_backward + if TritonConfig.enabled(logits.device, self._config.use_triton) + else fused_entropy_loss_forward_backward + )( logits, self._get_reference_model_logits(self._config.reference_model, kwargs, split_index), self._get_loss_mask(kwargs, split_index), grad_output=self._get_grad_output(kwargs), + grad_logits=grad_logits, group=self._parallel_dim.group if self._vocab_parallel else None, logits_scale_factor=self._logits_scale_factor, target_format=TargetFormat.logits, entropy_loss_type=self._config.loss_type, - use_triton=self._config.use_triton, ) - - -_ENTROPY_LOSS_IMPLEMENTATIONS = { - 
EntropyLossImplementation.torch: torch_entropy_loss_forward_backward, - EntropyLossImplementation.fused: fused_entropy_loss_forward_backward, - EntropyLossImplementation.triton: triton_entropy_loss_forward_backward, -} - - -def entropy_loss_forward_backward( - logits: torch.Tensor, # (*batch, vocab) - target: torch.Tensor, # (*batch,) or (*batch, vocab) - loss_mask: torch.Tensor | None, # (*batch,) - grad_output: float | None, - group: torch.distributed.ProcessGroup | None = None, - logits_scale_factor: float = 1.0, - temperature: float = 1.0, - target_format: TargetFormat = TargetFormat.labels, - entropy_loss_type: EntropyLossType = EntropyLossType.cross_entropy, - use_triton: bool | None = None, - **kwargs, -) -> tuple[torch.Tensor, torch.Tensor | None]: - """ - Select the appropriate implementation of cross-entropy. - The triton implementation from the triton submodule is the fastest and recommended one. - It doesn't have a tensor-parallel implementation, but can be computed in a sequence-tensor-parallel way, - which is faster and has a relatively small memory overhead. 
- """ - - if target_format == TargetFormat.labels: - Assert.eq(target.shape, logits.shape[:-1]) - Assert.eq(target.dtype, torch.int64) - assert loss_mask is None - else: - Assert.eq(target.shape, logits.shape) - assert target.dtype.is_floating_point, target.dtype - if loss_mask is not None: - Assert.eq(loss_mask.shape, logits.shape[:-1]) - return ( - triton_entropy_loss_forward_backward - if TritonConfig.enabled(logits.device, use_triton) - else fused_entropy_loss_forward_backward - )( - logits, - target, - loss_mask, - grad_output, - logits_scale_factor, - target_format, - entropy_loss_type, - group, - temperature=temperature, - **kwargs, - ) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 41e8942ac..766a5ed54 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -43,6 +43,7 @@ def forward_backward( logits: "torch.Tensor", kwargs: dict[str, typing.Any], split_index: int = 0, + grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": pass diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index 1df54f7a5..c606e2d68 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -21,6 +21,7 @@ def forward_backward( logits: "torch.Tensor", kwargs: dict[str, typing.Any], split_index: int = 0, + grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": return ( triton_z_loss_forward_backward @@ -32,6 +33,7 @@ def forward_backward( grad_output=self._get_grad_output(kwargs), group=self._parallel_dim.group if self._vocab_parallel else None, logits_scale_factor=self._logits_scale_factor, + grad_logits=grad_logits, ) @@ -53,7 +55,8 @@ def z_loss( def fused_z_loss_forward_backward( logits: torch.Tensor, loss_mask: torch.Tensor | None, - grad_output: float | None, + grad_logits: torch.Tensor | 
None = None, + grad_output: float | None = None, group: torch.distributed.ProcessGroup | None = None, logits_scale_factor: float = 1.0, ) -> tuple[torch.Tensor, torch.Tensor | None]: @@ -70,12 +73,14 @@ def fused_z_loss_forward_backward( per_sample_loss = per_sample_loss * loss_mask loss = per_sample_loss.mean() - if grad_output is None: - grad = None - else: + if grad_output is not None: grad_base = 2 * grad_output * (log_sum_exp_logits / sum_exp_logits) if loss_mask is not None: grad_base = grad_base * loss_mask grad = (grad_base.unsqueeze(-1) * exp_logits).to(logits.dtype) + if grad_logits is None: + grad_logits = grad + else: + grad_logits.add_(grad) - return loss, grad + return loss, grad_logits diff --git a/tests/layers/test_lm_losses.py b/tests/layers/test_lm_losses.py index 2f04a38ee..eefe5fbb4 100644 --- a/tests/layers/test_lm_losses.py +++ b/tests/layers/test_lm_losses.py @@ -109,16 +109,16 @@ def reference_dpo_loss( # _BATCH_SHAPES = ((64,), (16, 8)) _BATCH_SHAPES = ((1,),) _LOSS_PARAMETERS = ( - (8, 1.0, 1.0, False, DataType.float32, None), # Simple - (500, 1.0, 1.0, False, DataType.float32, None), # Simple - (256, 1.0, 1.0, False, DataType.float32, None), # Power of 2 - (500, None, 1.0, False, DataType.float32, None), # No grad - (500, 1.0, 4.0, False, DataType.float32, None), # Loss scaling - (500, 4.0, 1.0, False, DataType.float32, None), # Grad scaling - (500, 1.0, 1.0, True, DataType.float32, None), # Loss masking - (500, 1.0, 1.0, False, DataType.float16, None), # Fp16 - (500, 1.0, 1.0, False, DataType.float32, 256), # Looped - (1000, 2.0, 3.0, True, DataType.float16, 256), # Hard + (500, 1.0, 1.0, False, DataType.float32, None, False), # Simple + (256, 1.0, 1.0, False, DataType.float32, None, False), # Power of 2 + (500, None, 1.0, False, DataType.float32, None, False), # No grad + (500, 1.0, 1.0, False, DataType.float32, None, True), # Accumulate + (500, 1.0, 4.0, False, DataType.float32, None, False), # Loss scaling + (500, 4.0, 1.0, False, 
DataType.float32, None, False), # Grad scaling + (500, 1.0, 1.0, True, DataType.float32, None, False), # Loss masking + (500, 1.0, 1.0, False, DataType.float16, None, False), # Fp16 + (500, 1.0, 1.0, False, DataType.float32, 256, False), # Looped + (1000, 2.0, 3.0, True, DataType.float32, 256, True), # Hard ) @@ -132,6 +132,7 @@ def _test_entropy_loss( entropy_loss_type, dtype, block_size, + accumulate, group=None, ): if target_format == TargetFormat.labels and entropy_loss_type == EntropyLossType.reverse_kl: @@ -150,10 +151,15 @@ def _test_entropy_loss( target_format=target_format, entropy_loss_type=entropy_loss_type, ) + if accumulate: + previous_grad = torch.randn_like(grad_ref) + grad_ref = grad_ref + previous_grad + local_previous_grad = split_op(previous_grad, group, -1).contiguous() out_fused, grad_fused = fused_entropy_loss_forward_backward( logits=local_logits, target=local_target, loss_mask=loss_mask, + grad_logits=local_previous_grad.clone() if accumulate else None, grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, @@ -176,6 +182,7 @@ def _test_entropy_loss( logits=local_logits, target=local_target, loss_mask=loss_mask, + grad_logits=local_previous_grad.clone() if accumulate else None, grad_output=grad_output, logits_scale_factor=logits_scale_factor, target_format=target_format, @@ -195,20 +202,26 @@ def _test_entropy_loss( def _test_z_loss( - batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, group=None + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate, group=None ): logits, target, loss_mask = _get_lm_loss_inputs(num_columns, loss_masking, TargetFormat.logits, batch_shape, dtype) + local_logits = split_op(logits, group, -1).contiguous() out_ref, grad_ref = loss_forward_backward( grad_output, z_loss, logits, - loss_mask, + loss_mask=loss_mask, logits_scale_factor=logits_scale_factor, ) + if accumulate: + previous_grad = 
torch.randn_like(grad_ref) + grad_ref = grad_ref + previous_grad + local_previous_grad = split_op(previous_grad, group, -1).contiguous() out_fused, grad_fused = fused_z_loss_forward_backward( - split_op(logits, group, -1), - loss_mask, - grad_output, + logits=local_logits, + loss_mask=loss_mask, + grad_logits=local_previous_grad.clone() if accumulate else None, + grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, ) @@ -224,9 +237,10 @@ def _test_z_loss( if not triton_available: return out_triton, grad_triton = triton_z_loss_forward_backward( - split_op(logits, group, -1), - loss_mask, - grad_output, + logits=local_logits, + loss_mask=loss_mask, + grad_logits=local_previous_grad.clone() if accumulate else None, + grad_output=grad_output, group=group, logits_scale_factor=logits_scale_factor, block_size=block_size, @@ -245,7 +259,8 @@ def _test_z_loss( @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size", "accumulate"), + _LOSS_PARAMETERS, ) @pytest.mark.parametrize("target_format", TargetFormat) @pytest.mark.parametrize("entropy_loss_type", EntropyLossType) @@ -259,6 +274,7 @@ def test_entropy_loss( entropy_loss_type, dtype, block_size, + accumulate, ): _test_entropy_loss( batch_shape, @@ -270,16 +286,22 @@ def test_entropy_loss( entropy_loss_type, dtype, block_size, + accumulate, ) @pytest.mark.slow @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS + ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size", "accumulate"), + _LOSS_PARAMETERS, ) -def test_z_loss(batch_shape, num_columns, grad_output, 
logits_scale_factor, loss_masking, dtype, block_size): - _test_z_loss(batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size) +def test_z_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate +): + _test_z_loss( + batch_shape, num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size, accumulate + ) @pytest.mark.skip(reason="DPO loss is broken") @@ -296,8 +318,16 @@ def test_dpo_loss(): def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pathlib.Path, seed: int): for batch_shape in _BATCH_SHAPES: - for num_columns, grad_output, logits_scale_factor, loss_masking, dtype, block_size in _LOSS_PARAMETERS: - suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{"_".join([str(i) for i in batch_shape])}" + for ( + num_columns, + grad_output, + logits_scale_factor, + loss_masking, + dtype, + block_size, + accumulate, + ) in _LOSS_PARAMETERS: + suffix = f"{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{accumulate}-{"_".join([str(i) for i in batch_shape])}" # Entropy loss for entropy_loss_type in EntropyLossType: for target_format in TargetFormat: @@ -318,6 +348,7 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa entropy_loss_type, dtype, block_size, + accumulate, test_context.group, ) # Z loss @@ -332,6 +363,7 @@ def _run_lm_loss_distributed(test_context: DistributedTestContext, base_path: pa loss_masking, dtype, block_size, + accumulate, test_context.group, ) @@ -360,7 +392,8 @@ def test_run_lm_loss_distributed(run_parallel_script, result_path): @pytest.mark.depends_on(on=["test_lm_loss_distributed_dependency"]) @pytest.mark.parametrize("batch_shape", _BATCH_SHAPES) @pytest.mark.parametrize( - ("num_columns", "grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size"), _LOSS_PARAMETERS + ("num_columns", 
"grad_output", "logits_scale_factor", "loss_masking", "dtype", "block_size", "accumulate"), + _LOSS_PARAMETERS, ) @pytest.mark.parametrize( "loss_type", @@ -385,10 +418,11 @@ def test_lm_loss_distributed( loss_masking, dtype, block_size, + accumulate, ): report_subtest( result_path - / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{"_".join([str(i) for i in batch_shape])}", + / f"test_losses/{loss_type}-{num_columns}-{grad_output}-{logits_scale_factor}-{loss_masking}-{dtype}-{block_size}-{accumulate}-{"_".join([str(i) for i in batch_shape])}", 2, use_cuda=False, ) From 945dadc9ba644dd86dcde8cc86191693c9b30a09 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Feb 2026 13:22:44 -0500 Subject: [PATCH 17/22] Token dim --- fast_llm/engine/base_model/base_model.py | 1 + fast_llm/engine/schedule/runner.py | 24 ++- fast_llm/functional/autograd.py | 4 +- fast_llm/layers/attention/attention.py | 78 +++----- fast_llm/layers/attention/rotary/rotary.py | 1 + fast_llm/layers/block/block.py | 23 +-- fast_llm/layers/block/config.py | 5 +- fast_llm/layers/decoder/block.py | 168 +++++----------- .../layers/decoder/mlp/mixture_of_experts.py | 34 +--- fast_llm/layers/decoder/mlp/mlp.py | 15 +- fast_llm/layers/language_model/config.py | 10 - fast_llm/layers/language_model/embedding.py | 55 ++--- fast_llm/layers/language_model/head.py | 87 ++++---- .../language_model/loss/entropy_loss.py | 1 + fast_llm/layers/language_model/loss/loss.py | 24 ++- .../language_model/multi_token_prediction.py | 4 +- fast_llm/layers/ssm/gdn.py | 7 - fast_llm/layers/ssm/kda.py | 15 -- fast_llm/layers/ssm/mamba.py | 26 +-- fast_llm/layers/vision/embeddings.py | 11 +- fast_llm/models/gpt/huggingface.py | 14 +- fast_llm/models/gpt/model.py | 188 +++++++++--------- fast_llm/models/multimodal/model.py | 57 +++--- tests/layers/test_lm_head.py | 55 +++-- tests/layers/test_ssm.py | 2 - tests/layers/test_varlen.py | 13 +- 
tests/test_loss_mask.py | 8 +- tests/utils/distributed_configs.py | 32 +-- tests/utils/model_configs.py | 4 + 29 files changed, 368 insertions(+), 598 deletions(-) diff --git a/fast_llm/engine/base_model/base_model.py b/fast_llm/engine/base_model/base_model.py index ffffbed50..de64d905a 100644 --- a/fast_llm/engine/base_model/base_model.py +++ b/fast_llm/engine/base_model/base_model.py @@ -179,6 +179,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + extra_kwargs: dict[str, typing.Any] | None = None, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase pass diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 4d31324fe..2d7e02f77 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -19,6 +19,7 @@ from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.config import EventType, MockEvent, MockStream, ScheduleConfig, StepType, StreamType from fast_llm.engine.schedule.schedule import Schedule, Step +from fast_llm.layers.block.config import BlockKwargs from fast_llm.logging import log_memory_usage from fast_llm.utils import Assert @@ -339,15 +340,15 @@ def _preprocess_data( phase=context.phase, iteration=context.iteration, metrics=context.metrics, + extra_kwargs={ + "grad_output": grad_output, + "micro_batch": micro_batch, + "num_micro_batches": batch_config.sequential_micro_batches, + "micro_batch_splits": batch_config.micro_batch_splits, + }, ) for micro_batch_split, (input_, kwargs) in enumerate(micro_batch_data): - kwargs.update( - grad_output=grad_output, - micro_batch=micro_batch, - micro_batch_split=micro_batch_split, - num_micro_batches=batch_config.sequential_micro_batches, - micro_batch_splits=batch_config.micro_batch_splits, - ) + kwargs.update(micro_batch_split=micro_batch_split) data_index = context.schedule.get_data_index(micro_batch, micro_batch_split) 
if self._stages_owned[0]: context.inputs[context.schedule.get_step(StepType.forward, 0, data_index).global_index] = input_ @@ -405,6 +406,15 @@ def _recv(self, context: BatchContext, step: Step) -> None: self._record_event(context, EventType.compute_wait_pipe, step) def _forward(self, context: BatchContext, step: Step) -> None: + print( + "IASINBUI", + step, + ( + context.batch[step.data_index].get(BlockKwargs.grad_output) + if step.data_index in context.batch + else "PPPPP" + ), + ) output, grad_context = self._stages[step.stage].forward( self._get_forward_input(context, step), context.batch[step.data_index], diff --git a/fast_llm/functional/autograd.py b/fast_llm/functional/autograd.py index 3e8e31cea..586f833b3 100644 --- a/fast_llm/functional/autograd.py +++ b/fast_llm/functional/autograd.py @@ -62,8 +62,8 @@ def grad_is_context(grad_output: torch.Tensor, context: torch.Tensor) -> torch.T class AuxiliaryLoss(torch.autograd.Function): @staticmethod - def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float) -> torch.Tensor: # noqa - ctx.grad = torch.full_like(aux_loss, grad) + def forward(ctx, input_: torch.Tensor, aux_loss: torch.Tensor, grad: float | None) -> torch.Tensor: # noqa + ctx.grad = None if grad is None else torch.full_like(aux_loss, grad) return input_ @staticmethod diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index d6eab0eb2..be8b31f39 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -3,7 +3,7 @@ import torch from fast_llm.core.distributed import set_generator -from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op, swap_mult_dim +from fast_llm.core.ops import gather_op, reduce_op, reduce_scatter_op from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ @@ -12,7 +12,6 @@ from 
fast_llm.functional.autograd import wrap_forward_backward from fast_llm.layers.attention.config import AttentionConfig, AttentionImplementation, AttentionKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen -from fast_llm.layers.block.config import BlockDimNames from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.tensor import TensorMeta @@ -113,7 +112,7 @@ def __init__( CompositeTensorDim("value", (head_group_dim, head_size_dim)), ), ) - dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) + self._dense_dim = CompositeTensorDim("dense", (head_group_dim, group_heads_dim, head_size_dim)) self._softmax_scale = self._config.head_size ** (-self._config.softmax_scale_power) @@ -152,7 +151,7 @@ def __init__( # Output. self.dense = self._config.dense_layer.get_layer( - dense_dim, + self._dense_dim, hidden_dim, default_weight_initialization=init_normal_(std=self._hidden_size**-0.5), default_add_bias=self._config.add_linear_biases, @@ -163,22 +162,13 @@ def __init__( # Debug dims self._query_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, CompositeTensorDim("heads", (head_group_dim, group_heads_dim)), head_size_dim, ) self._kv_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, head_group_dim, head_size_dim, ) - self._context_dims = ( - BlockDimNames.batch, - BlockDimNames.sequence_q, - dense_dim, - ) def _attn_backup( self, @@ -269,7 +259,7 @@ def _attn_flash( ) def _query_key_value_forward( - self, input_: torch.Tensor, sequence_first: bool + self, input_: torch.Tensor ) -> tuple[torch.Tensor, torch.Tensor, dict[str, typing.Any]]: key_value, key_value_context = self.key_value.forward_only(input_) @@ -292,10 +282,7 @@ def _query_key_value_forward( if handle: handle.wait() - if self._sequence_data_parallel_dim.group and not sequence_first: - key_value = swap_mult_dim(key_value, self._sequence_parallel, 0, 1) - - 
context = {"query": query_context, "key_value": key_value_context, "sequence_first": sequence_first} + context = {"query": query_context, "key_value": key_value_context} return query, key_value, context def _query_key_value_backward( @@ -305,7 +292,7 @@ def _query_key_value_backward( key_value_grad, handle = reduce_scatter_op( key_value_grad, group=self._sequence_data_parallel_dim.group, - dim=1 - context["sequence_first"], + dim=0, async_op=True, ) @@ -331,15 +318,20 @@ def _forward( losses: dict[str, typing.Any] | None = None, metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: - sequence_first = kwargs[AttentionKwargs.sequence_first] - query, key_value = self._query_key_value(input_, sequence_first) + query, key_value = self._query_key_value(input_) + + # Separate the batch and sequence dimensions + token_dims = (kwargs[AttentionKwargs.batch_dim], kwargs[AttentionKwargs.sequence_q_dim]) + token_shape = tuple(dim.size for dim in token_dims) + query = query.unflatten(0, token_shape) + key_value = key_value.unflatten(0, token_shape) # TODO: Move the rest to function. if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: - assert sequence_first # Clear the lists so tensors can be de-allocated - key_value = torch.cat((past_key_values.pop(0), key_value), dim=0) + # TODO: ===== Check ===== + key_value = torch.cat((past_key_values.pop(0), key_value), dim=1) if (presents := kwargs.get(AttentionKwargs.presents)) is not None: # Return the presents as a leaf tensors so the gradients from later micro-sequences @@ -348,26 +340,17 @@ def _forward( # Manually add the gradients from later micro-sequences. 
key_value = AttachGrad.apply(key_value, present) - if self._sequence_data_parallel_dim.group: - key_value = ( - key_value[: kwargs[AttentionKwargs.sequence_k_dim].size] - if sequence_first - else key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] - ) - - if sequence_first: - # TODO: Optimize (is contiguous avoidable?) - query = query.transpose(0, 1).contiguous() - key_value = key_value.transpose(0, 1).contiguous() - + # TODO: ===== Check ===== + key_value = key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) + # TODO: ===== Expand batch seq dim ===== query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size) - self._debug(query, "query_rotary_input", self._query_dims, kwargs) - self._debug(key, "key_rotary_input", self._kv_dims, kwargs) + self._debug(query, "query_rotary_input", token_dims + self._query_dims, kwargs) + self._debug(key, "key_rotary_input", token_dims + self._kv_dims, kwargs) query, key = self._rotary(query, key, kwargs) with set_generator(self._distributed.tp_generator): @@ -379,22 +362,17 @@ def _forward( else: raise NotImplementedError(self._implementation) - self._debug(query, "query", self._query_dims, kwargs) - self._debug(key, "key", self._kv_dims, kwargs) - self._debug(value, "value", self._kv_dims, kwargs) - self._debug(input_, "context", self._context_dims, kwargs) + self._debug(query, "query", token_dims + self._query_dims, kwargs) + self._debug(key, "key", token_dims + self._kv_dims, kwargs) + self._debug(value, "value", token_dims + self._kv_dims, kwargs) + self._debug(input_, "context", token_dims + (self._dense_dim,), kwargs) - if sequence_first: - # TODO: Optimize (is contiguous avoidable? Transpose dense output?) 
- input_ = input_.transpose(0, 1).contiguous() - out, bias = self.dense(input_) - self._debug(out, None, kwargs.get(AttentionKwargs.hidden_dims), kwargs) + out, bias = self.dense(input_.flatten(0, 1)) + self._debug(out, None, token_dims + (self._hidden_dim,), kwargs) return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - batch_dim: TensorDim = kwargs[AttentionKwargs.hidden_dims][1 if kwargs[AttentionKwargs.sequence_first] else 0] - - # Using this one since `hidden_dims` may be sequence-tensor-parallel, and attention is not. + batch_dim: TensorDim = kwargs[AttentionKwargs.batch_dim] sequence_q_dim: TensorDim = kwargs[AttentionKwargs.sequence_q_dim] sequence_k_dim: TensorDim = kwargs[AttentionKwargs.sequence_k_dim] @@ -435,7 +413,7 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c partly_out_of_window = max(sequence_k - fully_out_of_window - self._config.window_size, 0) attention_compute -= (partly_out_of_window * (partly_out_of_window + 1) * attn_compute_base) // 2 - dense_input = TensorMeta.from_dims((batch_dim, sequence_q_dim, self._context_dims[-1])) + dense_input = TensorMeta.from_dims((*input_.dims[:-1], self._dense_dim)) # TODO: Add marginal compute? (ex. 
softmax) return sum( diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index 307256a72..e24d85e36 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -235,6 +235,7 @@ def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if TritonConfig.enabled(query.device) else rotary_embeddings_real + print("AAAAA", query.shape, kwargs[AttentionKwargs.rotary_freq_q].shape) query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key diff --git a/fast_llm/layers/block/block.py b/fast_llm/layers/block/block.py index a1942cab1..dc7334b45 100644 --- a/fast_llm/layers/block/block.py +++ b/fast_llm/layers/block/block.py @@ -47,7 +47,7 @@ def __call__( if bias is not None: assert tensor is not None tensor = tensor + bias - meta = self._get_meta(tensor, name, dims, kwargs) + meta = self._get_meta(tensor, name, dims) if output_hidden_state: kwargs[BlockKwargs.hidden_states][name] = (meta, tensor) @@ -68,7 +68,7 @@ def __call__( "", tensor, level=level, - meta=self._get_meta(tensor, name + f"{name}.grad", dims, kwargs), + meta=self._get_meta(tensor, name + f"{name}.grad", dims), **logging_kwargs, ) @@ -76,26 +76,13 @@ def _get_meta( self, tensor: torch.Tensor | None, name: str, - dims: tuple[TensorDim | str, ...] | None, - kwargs: dict[str, typing.Any], - ) -> TensorMeta | None: - if tensor is None: - return None + dims: tuple[TensorDim | str | None, ...] 
| None, + ) -> TensorMeta: if dims is None: dims = tuple(f"dim_{i}" for i in range(tensor.ndim)) - hidden_dims = {} - if BlockKwargs.hidden_dims in kwargs: - for dim in kwargs[BlockKwargs.hidden_dims]: - hidden_dims[dim.name] = dim - if BlockKwargs.sequence_q_dim in kwargs: - hidden_dims[kwargs[BlockKwargs.sequence_q_dim].name] = kwargs[BlockKwargs.sequence_q_dim] return TensorMeta.from_dims( tuple( - ( - dim - if isinstance(dim, TensorDim) - else hidden_dims[dim] if dim in hidden_dims else TensorDim(dim, tensor.size(i)) - ) + (dim if isinstance(dim, TensorDim) else TensorDim(f"dim_{i}" if dim is None else dim, tensor.size(i))) for i, dim in enumerate(dims) ), tensor_name=name, diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index fd76d36cb..a1b600445 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -31,10 +31,11 @@ class BlockDimNames: class BlockKwargs: - sequence_first = "sequence_first" - hidden_dims = "hidden_dims" + batch_dim = "batch_dim" sequence_q_dim = "sequence_q_dim" sequence_k_dim = "sequence_k_dim" + token_dim = "token_dim" + hidden_token_dim = "hidden_token_dim" # TODO: These are confusing sequence_length = "sequence_length" sequence_lengths = "sequence_lengths" diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index 8f6e360fd..bb9df3fb9 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -126,143 +126,77 @@ def forward( metrics: dict[str, typing.Any] | None = None, ) -> torch.Tensor: if isinstance(input_, TensorMeta): - dims = kwargs[BlockKwargs.hidden_dims] + dims = input_.dims if self._return_input: dims = (TensorDim("stacked_input_output", 2),) + dims return TensorMeta.from_dims(dims, tensor_name=f"{self.module_name} output", dtype=input_.dtype) generator = self._distributed.tp_generator if self._sequence_parallel else self._distributed.pp_generator - self._debug(None, "begin", kwargs.get(BlockKwargs.hidden_dims), 
kwargs) + hidden_dims = (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim) + self._debug(None, "begin", hidden_dims, kwargs) fw_input = input_ hidden_states = self.norm_1(input_) - self._debug(hidden_states, "norm_1", kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(hidden_states, "norm_1", hidden_dims, kwargs) hidden_states, bias = self.mixer(hidden_states, kwargs, metrics=metrics) - hidden_states, bias = self.activation_distillation_loss(hidden_states, bias, kwargs, losses, metrics) + self._debug(hidden_states.detach(), "mixer_output", hidden_dims, kwargs, bias=bias) + if self._config.distillation_model is not None and self.training: + if bias is not None: + hidden_states = hidden_states + bias + bias = None + hidden_states = self._activation_distillation_loss(hidden_states, kwargs, losses, metrics) with set_generator(generator): input_ = self._bias_dropout_add(hidden_states, bias, input_) - self._debug(input_, "mixer_residual", kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(input_, "mixer_residual", hidden_dims, kwargs) hidden_states = self.norm_2(input_) - self._debug(hidden_states, "norm_2", kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(hidden_states, "norm_2", hidden_dims, kwargs) hidden_states, bias = self.mlp(hidden_states, kwargs, losses, metrics) with set_generator(generator): hidden_states = self._bias_dropout_add(hidden_states, bias, input_) - self._debug(hidden_states, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(hidden_states, None, hidden_dims, kwargs) if self._return_input: hidden_states = torch.stack((fw_input, hidden_states), dim=0) return hidden_states - def activation_distillation_loss(self, hidden_states, bias, kwargs, losses, metrics): - """ - Maybe apply activation distillation loss and setup backward hooks. 
- """ - mixer_output = hidden_states if bias is None else hidden_states + bias - - # Teacher: output mixer activations via _debug interface - self._debug(mixer_output.detach(), "mixer_output", kwargs.get(BlockKwargs.hidden_dims), kwargs) - - # Student gets teacher activations and computes the activation-level loss. - activation_targets = kwargs.get(BlockKwargs.activation_distillation_targets) - key = f"{self.module_name}.mixer_output" - if ( - activation_targets is not None - and self.training - and (teacher_output := activation_targets.pop(key, None)) is not None - ): - # Compare student mixer output with the teacher's stored activation and accumulate the loss. - teacher_tensor = teacher_output.detach().to(device=mixer_output.device, dtype=mixer_output.dtype) - Assert.eq(teacher_tensor.shape, mixer_output.shape) - # TODO: un-scaled loss for reporting? Average loss over layers? - # L2 loss - activation_loss_factor = self._config.distillation_loss_weight - # (batch, sequence, hidden) or (sequence, batch, hidden). Take the norm over hidden dim. 
- - # Handle possible padding by using pre-computed activation mask - sequence_first = kwargs.get(BlockKwargs.sequence_first, False) - activation_mask = kwargs.get(BlockKwargs.activation_mask) - - if activation_mask is not None: - # Use pre-computed activation mask (bool tensor where True = valid token) - mask = activation_mask.to(dtype=mixer_output.dtype) - if sequence_first: - # (batch, sequence) -> (sequence, batch) - mask = mask.T - - # Compute masked L2 loss: norm over hidden dim, then apply mask - per_token_loss = torch.norm( - mixer_output - teacher_tensor, p=2, dim=-1 - ) # (batch, sequence) or (sequence, batch) - - # Slice mask to match per_token_loss shape (for sequence parallelism) - # When sequence_tensor_parallel is enabled, per_token_loss only has local sequence length - if mask.shape != per_token_loss.shape: - # Calculate the sequence offset for this rank using the hidden_dims parallel rank - hidden_dims = kwargs.get(BlockKwargs.hidden_dims) - seq_dim_idx = 0 if sequence_first else 1 - hidden_seq_dim = hidden_dims[seq_dim_idx] if hidden_dims else None - - if hidden_seq_dim and hidden_seq_dim.parallel_dim: - # Use the rank from the actual parallel dimension used by hidden states - local_seq_length = per_token_loss.shape[0] if sequence_first else per_token_loss.shape[1] - seq_offset = hidden_seq_dim.parallel_dim.rank * local_seq_length - else: - seq_offset = 0 - - if sequence_first: - # mask: (sequence, batch), per_token_loss: (local_sequence, batch) - mask = mask[seq_offset : seq_offset + per_token_loss.shape[0], :] - else: - # mask: (batch, sequence), per_token_loss: (batch, local_sequence) - mask = mask[:, seq_offset : seq_offset + per_token_loss.shape[1]] - - masked_loss = per_token_loss * mask - local_loss_sum = torch.sum(masked_loss) - total_count = int(mask.sum().item()) - else: - # No activation_mask available, compute loss on all tokens - per_token_loss = torch.norm( - mixer_output - teacher_tensor, p=2, dim=-1 - ) # (batch, sequence) or 
(sequence, batch) - local_loss_sum = torch.sum(per_token_loss) - # mixer_output.shape is (batch, sequence, hidden) or (sequence, batch, hidden) - # In either case, dims 0 and 1 are batch and sequence - total_count = mixer_output.shape[0] * mixer_output.shape[1] - - # All-reduce across tensor-parallel group if sequence-parallel is enabled - if self._sequence_parallel and self._distributed.tensor_group is not None: - all_reduce(local_loss_sum, group=self._distributed.tensor_group, op=ReduceOp.SUM) - if activation_mask is not None: - # Different ranks may have different amounts of padding - total_count_tensor = torch.tensor(total_count, device=mixer_output.device, dtype=torch.int64) - all_reduce(total_count_tensor, group=self._distributed.tensor_group, op=ReduceOp.SUM) - total_count = int(total_count_tensor.item()) - else: - # All ranks contribute the same count - total_count *= self._distributed.tensor_group.size() - - activation_loss = local_loss_sum / total_count - scaled_activation_loss = activation_loss_factor * activation_loss - - # Backward hooks - hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, 1.0) - bias = AuxiliaryLoss.apply(bias, scaled_activation_loss, 1.0) if bias is not None else None - # Logging - if losses is not None and self._distillation_loss_name in losses: - losses[self._distillation_loss_name].append(activation_loss.detach()) - # Per-layer metrics - if metrics is not None: - metrics[f"{self.module_name}/activation_distillation_loss"] = activation_loss.detach() - - # If using stochastic mixer, also log per-mixer-type activation distillation loss - from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer - - if isinstance(self.mixer, StochasticMixer): - selected_mixer = self.mixer._last_selected_mixer - metrics[f"{self.module_name}/activation_distillation_loss/{selected_mixer}"] = ( - activation_loss.detach() - ) - return hidden_states, bias + def _activation_distillation_loss(self, hidden_states, kwargs, 
losses, metrics): + Assert.incl( + mixer_output_name := f"{self.module_name}.mixer_output", + reference_hidden_states := kwargs[f"reference_{self._config.distillation_model}_hidden_states"], + ) + teacher_hidden_states = reference_hidden_states.pop(mixer_output_name) + + # L2 loss + per_token_loss = torch.norm(hidden_states - teacher_hidden_states, dim=-1, dtype=torch.float32) + if (activation_mask := kwargs.get(BlockKwargs.activation_mask)) is not None: + per_token_loss = per_token_loss * activation_mask + loss = torch.mean(per_token_loss) + + # All-reduce across tensor-parallel group if sequence-parallel is enabled + if self._sequence_parallel and self._distributed.tensor_group is not None: + all_reduce(loss, group=self._distributed.tensor_group, op=ReduceOp.AVG) + + scaled_activation_loss = self._config.distillation_loss_weight * loss + + # Backward hook + print(kwargs[BlockKwargs.grad_output]) + hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, kwargs.get(BlockKwargs.grad_output)) + + # Logging + if losses is not None and self._distillation_loss_name in losses: + losses[self._distillation_loss_name].append(loss.detach()) + + if metrics is not None: + metrics[f"{self.module_name}/activation_distillation_loss"] = loss.detach() + + # If using stochastic mixer, also log per-mixer-type activation distillation loss + from fast_llm.layers.decoder.stochastic_mixer import StochasticMixer + + if isinstance(self.mixer, StochasticMixer): + metrics[f"{self.module_name}/activation_distillation_loss/{self.mixer._last_selected_mixer}"] = ( + loss.detach() + ) + return hidden_states def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(normalization, bias_dropout_add) diff --git a/fast_llm/layers/decoder/mlp/mixture_of_experts.py b/fast_llm/layers/decoder/mlp/mixture_of_experts.py index 413a88ed6..13ba79a7a 100644 --- a/fast_llm/layers/decoder/mlp/mixture_of_experts.py +++ b/fast_llm/layers/decoder/mlp/mixture_of_experts.py @@ -12,7 +12,6 @@ from fast_llm.functional.autograd import AuxiliaryLoss from fast_llm.functional.triton.mlp import mlp_autograd, mlp_autograd_looped from fast_llm.functional.triton.sparse_copy import get_sparse_map -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.mlp.config import MLPLossNames, MoEMLPConfig, RoutingType @@ -94,11 +93,8 @@ def _forward( return TensorMeta.from_dims(input_.dims[:-1] + (self._output_dim,), "MLP output"), None hidden_states = input_.flatten(0, -2) logits = self.router(hidden_states) - logit_dims = ( - kwargs[BlockKwargs.hidden_dims][:-1] + (self._top_expert_dim,) - if BlockKwargs.hidden_dims in kwargs - else None - ) + hidden_token_dim = kwargs[BlockKwargs.hidden_token_dim] + logit_dims = (hidden_token_dim, self._top_expert_dim) self._debug(logits, "Router logits", logit_dims, kwargs) # Apply z_loss if applicable @@ -130,7 +126,7 @@ def _forward( self._debug(top_experts, "router_top_experts", logit_dims, kwargs) out = self._mlp_forward(hidden_states, scores, top_experts).view_as(input_) # noqa - self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + self._debug(out, None, (hidden_token_dim, self._hidden_dim), kwargs) return out, None def _forward_dropless( @@ -241,24 +237,14 @@ def _sinkhorn_activation(self, logits: torch.Tensor) -> torch.Tensor: ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: - if kwargs[AttentionKwargs.sequence_first]: - sequence_dim, batch_dim, hidden_dim = input_.dims - else: - batch_dim, 
sequence_dim, hidden_dim = input_.dims - - # Applying the tokens per expert on the batch dim so the super() call works as intended. - moe_batch_dim = TensorDim( - f"moe_{batch_dim.name}", batch_dim.global_size * self._config.experts_per_token, batch_dim.parallel_dim + token_dim, hidden_dim = input_.dims + # Applying the tokens per expert on the token dim so the super() call works as intended. + moe_token_dim = TensorDim( + f"moe_{token_dim.name}", token_dim.global_size * self._config.experts_per_token, token_dim.parallel_dim + ) + moe_input = TensorMeta.from_dims( + (moe_token_dim, hidden_dim), tensor_name=f"moe_{input_.tensor_name}", dtype=input_.dtype ) - - if kwargs[AttentionKwargs.sequence_first]: - dims = sequence_dim, moe_batch_dim, hidden_dim - else: - dims = moe_batch_dim, sequence_dim, hidden_dim - - # Also adjust the dtype in case of full-precision residual - moe_input = TensorMeta.from_dims(dims, tensor_name=f"moe_{input_.tensor_name}", dtype=input_.dtype) - return super().get_compute_usage(moe_input, kwargs, config) + self.router.get_compute_usage(input_, config) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: diff --git a/fast_llm/layers/decoder/mlp/mlp.py b/fast_llm/layers/decoder/mlp/mlp.py index 88c86c8aa..1048f7c2a 100644 --- a/fast_llm/layers/decoder/mlp/mlp.py +++ b/fast_llm/layers/decoder/mlp/mlp.py @@ -9,7 +9,6 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.config import TritonConfig from fast_llm.functional.triton.mlp import mlp_autograd, torch_mlp_activation, triton_mlp_activation_autograd -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias @@ -85,16 +84,11 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c if config.hardware and 
self._config.recompute_level.recompute_layer_1 else config ) - - # Get the layer 2 input dims, accounting for ordering and possible sequence-parallelism. - # TODO: Don't rely on kwargs dimensions. - if kwargs[AttentionKwargs.sequence_first]: - dims = (kwargs[AttentionKwargs.sequence_q_dim], input_.dims[1], self._intermediate_2_dim) - else: - dims = (input_.dims[0], kwargs[AttentionKwargs.sequence_q_dim], self._intermediate_2_dim) # Also adjust the dtype in case of full-precision residual layer_2_input = TensorMeta.from_dims( - dims, tensor_name="intermediate_1", dtype=self._distributed_config.compute_dtype.torch + (input_.dims[0], self._intermediate_2_dim), + tensor_name="intermediate_1", + dtype=self._distributed_config.compute_dtype.torch, ) # TODO: Add marginal compute? (ex. activation, gate + up) @@ -141,6 +135,5 @@ def _forward( bias = self.layer_2.bias if self._parallel_dim.group else None # Use None for dims when output_dim differs from hidden_dim (e.g., adapter projections) # to let _debug infer dims from actual tensor shape - dims = None if self._output_dim != self._hidden_dim else kwargs.get(BlockKwargs.hidden_dims) - self._debug(out, None, dims, kwargs, bias=bias) + self._debug(out, None, (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim), kwargs, bias=bias) return out, bias diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index 9ba1f3433..e3446bba6 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -146,8 +146,6 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): desc="Configuration for the LM output layer (weight). Ignored for tied embeddings", hint=FieldHint.architecture, ) - # TODO: Option to chose whether to split in batch or sequence dimension? 
- # (Currently split merged batch and sequence, depends on `sequence_first`) cross_entropy_splits: int = Field( default=1, desc="Split the logit and cross-entropy computation into this many fragment, to reduce memory usage.", @@ -274,14 +272,6 @@ class LanguageModelConfig(BlockConfig): hint=FieldHint.architecture, valid=check_field(Assert.gt, 0), ) - sequence_first: bool | None = Field( - default=None, - desc="Override the default dimension ordering", - doc="By default, the hidden states are stored with dimensions (batch, sequence, ...), as it makes attention more efficient." - " However, some settings such as sequence-tensor/data/pipelineo-parallel instead require the ordering (sequence, batch, ...)." - " Setting this parameter overrides the default choice. Note that setting to `False` will either do nothing or raise an error.", - hint=FieldHint.testing, - ) @property def layer_class(self) -> "type[LanguageModel]": diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index 93850d24c..c6df8f62b 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -8,6 +8,7 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames +from fast_llm.layers.attention.preprocessing import preprocess_for_varlen from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelKwargs @@ -28,11 +29,6 @@ class LanguageModelEmbedding[ConfigType: LanguageModelEmbeddingsConfig](Block[Co layer_count: float = 1000.0 _config: ConfigType - # Preprocessing - _rotary_embedding_frequencies: torch.Tensor - _position_ids: torch.Tensor - _tensor_cache_max_sequence_length: int = -1 - def __init__( self, config: 
ConfigType, @@ -84,7 +80,7 @@ def _forward( token_ids: torch.Tensor, position_ids: torch.Tensor | None, mask_inputs: bool, - embedding_map: tuple[torch.Tensor, torch.Tensor] | None, + embedding_map: torch.Tensor, ) -> torch.Tensor: Assert.eq(position_ids is None, self.position_embeddings_weight is None) group = self._parallel_dim.group @@ -102,7 +98,7 @@ def _forward( if self._sequence_parallel: input_ = gather(input_, group=group, dim=0) # Out-of-place equivalent of `embeddings[embedding_map] += input_` - embeddings = embeddings.index_put(embedding_map, input_[: embedding_map[0].size(0)], accumulate=True) + embeddings = embeddings.index_put((embedding_map,), input_[: embedding_map.size(0)], accumulate=True) if self._sequence_parallel: embeddings = split(embeddings, group=group, dim=0) @@ -122,7 +118,7 @@ def _forward( if self.position_embeddings_weight is not None: embeddings = embeddings + torch.nn.functional.embedding(position_ids, self.position_embeddings_weight) if mask_inputs: - embeddings = embeddings * token_mask.unsqueeze(2) + embeddings = embeddings * token_mask.unsqueeze(-1) if input_ is not None: # TODO: Accumulate redundant with masking? 
@@ -131,12 +127,12 @@ def _forward( input_ = gather(input_, group=group, dim=0) embeddings_ = embeddings.new_zeros(embeddings.shape[0] * group.size(), *embeddings.shape[1:]) embeddings_ = embeddings_.index_put( - embedding_map, input_[: embedding_map[0].size(0)], accumulate=True + (embedding_map,), input_[: embedding_map.size(0)], accumulate=True ) embeddings = embeddings + split(embeddings_, group=group, dim=0) else: embeddings = embeddings.index_put( - embedding_map, input_[: embedding_map[0].size(0)], accumulate=True + (embedding_map,), input_[: embedding_map.size(0)], accumulate=True ) with set_generator( @@ -154,7 +150,7 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return TensorMeta.from_dims( - kwargs[LanguageModelKwargs.hidden_dims], + (kwargs[LanguageModelKwargs.hidden_token_dim], self._hidden_dim), tensor_name=f"{self.module_name} output", dtype=self._residual_dtype, ) @@ -167,8 +163,6 @@ def forward( # TODO: Support multiple encoders. # TODO: Support pipeline-parallel. token_ids = kwargs.get(LanguageModelKwargs.token_ids) - # Drop the placeholder batch dimension, remove patch padding. 
- input_ = input_.squeeze(int(kwargs[LanguageModelKwargs.sequence_first])) out = self._forward( input_, @@ -178,7 +172,7 @@ def forward( kwargs.get(LanguageModelKwargs.mask_inputs), embedding_map, ) - self._debug(out, None, kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) + self._debug(out, None, (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: @@ -188,29 +182,12 @@ def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], c def preprocess(self, kwargs: dict[str, typing.Any]) -> None: if not self._config.position_embeddings.enabled: return - self._create_position_embeddings(kwargs[LanguageModelKwargs.sequence_length], self._distributed.device) - sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size - sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size - if not self._config.cross_document_position_embeddings: - position_ids = torch.stack( - [ - torch.cat([torch.arange(x) for x in sample_lens]) - for sample_lens in kwargs[LanguageModelKwargs.sequence_lengths] - ] - ).to(self._distributed.device, dtype=torch.int64) - position_ids = position_ids[:, sequence_k - sequence_q : sequence_k] - if kwargs[LanguageModelKwargs.sequence_first]: - position_ids = position_ids.transpose(0, 1) - kwargs[LanguageModelKwargs.position_ids] = position_ids + # TODO: Move to data preprocessing. 
+ if self._config.cross_document_position_embeddings: + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + sequence_k = kwargs[LanguageModelKwargs.sequence_k_dim].size + kwargs[LanguageModelKwargs.position_ids] = torch.arange( + sequence_k - sequence_q, sequence_k, device=self._distributed.device, dtype=torch.int64 + ).repeat(kwargs[LanguageModelKwargs.batch_dim].size) else: - kwargs[LanguageModelKwargs.position_ids] = self._position_ids[ - sequence_k - sequence_q : sequence_k - ].unsqueeze(int(kwargs[LanguageModelKwargs.sequence_first])) - - def _create_position_embeddings(self, sequence_length: int, device: torch.device) -> None: - if sequence_length <= self._tensor_cache_max_sequence_length: - return - self._tensor_cache_max_sequence_length = sequence_length - - Assert.leq(sequence_length, self._config.num_position_embeddings) - self._position_ids = torch.arange(0, sequence_length, device=device, dtype=torch.int64) + preprocess_for_varlen(kwargs, self._distributed.device, return_position_ids=True) diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index 144074ca5..c5bf9ff9b 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -7,7 +7,6 @@ from torch._C._distributed_c10d import ReduceOp # noqa from torch.distributed import all_reduce -from fast_llm.core.ops import gather_op from fast_llm.engine.base_model.config import LossDef, ResourceUsageConfig from fast_llm.engine.config_utils.data_type import DataType from fast_llm.engine.config_utils.initialization import init_normal_ @@ -15,8 +14,7 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import Block -from fast_llm.layers.block.config import 
BlockDimNames +from fast_llm.layers.block.block import Block, BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LM_HEAD_LOSS_NAME, @@ -34,13 +32,15 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](Block[ConfigType]): +class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](BlockBase[ConfigType]): + heads: "list[LanguageModelHead]" + @abc.abstractmethod def get_output_weights(self) -> list[torch.Tensor]: pass -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType]): +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType], Block): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) @@ -99,19 +99,21 @@ def __init__( loss_configs = ( self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()} ) - self._losses = [ - loss_config.get_layer( - distributed_config, - self._get_full_loss_name(name), - self._prediction_distance, - self._prediction_heads, - self._vocab_parallel, - self._config.cross_entropy_splits, - self._config.logits_scale_factor, - self._loss_coefficient, - ) - for name, loss_config in loss_configs.items() - ] + self.losses = torch.nn.ModuleList( + [ + loss_config.get_layer( + distributed_config, + self._get_full_loss_name(name), + self._prediction_distance, + self._prediction_heads, + self._vocab_parallel, + self._config.cross_entropy_splits, + self._config.logits_scale_factor, + self._loss_coefficient, + ) + for name, loss_config in loss_configs.items() + ] + ) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: # TODO: Add marginal compute? 
(loss) @@ -168,7 +170,12 @@ def _forward_backward( ln_output = self.final_norm(input_) # Transformers expect normalized outputs for the last transformer layer, # so we add the norm output to the hidden states. - self._debug(ln_output, "final_norm", kwargs.get(LanguageModelKwargs.hidden_dims), kwargs) + self._debug( + ln_output, + "final_norm", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), + kwargs, + ) loss, ln_output_grad = self._logits_loss_forward_backward(ln_output.detach(), kwargs, losses) if ln_output_grad is None: return loss, None @@ -185,18 +192,7 @@ def _logits_loss_forward_backward( if not self.training: logits, _ = self._logits_loss_forward_backward_partial(input_, kwargs, return_logits=True) - # TODO: Make a proper way of returning the model output. - logits = logits.detach() - if kwargs.get("global_logits"): - if self._vocab_parallel: - logits = gather_op(logits, self._parallel_dim.group, 2) - elif self._sequence_parallel_logits: - logits = gather_op( - logits, self._parallel_dim.group, 0 if kwargs[LanguageModelKwargs.sequence_first] else 1 - ) - kwargs["logits" if self._prediction_distance == 0 else f"logits_{self._prediction_distance}"] = ( - logits.detach() - ) + self._debug(logits, "logits", (kwargs[LanguageModelKwargs.hidden_token_dim], self._vocab_dim), kwargs) return None, None input_ = input_.flatten(0, -2) @@ -230,7 +226,7 @@ def _logits_loss_forward_backward( total_loss = sum( (loss_.weight / self._config.cross_entropy_splits) * loss_dict[loss_.name] - for loss_ in self._losses + for loss_ in self.losses if loss_.weight != 0.0 and loss_.name in loss_dict ) @@ -240,7 +236,7 @@ def _logits_loss_forward_backward( if all_losses_dict is not None: all_losses_dict[self._total_loss_name].append(total_loss) - if len(self._losses) > 1 or any(loss_.weight != 1.0 for loss_ in self._losses): + if len(self.losses) > 1 or any(loss_.weight != 1.0 for loss_ in self.losses): for name, loss_value in loss_dict.items(): if 
self._config.cross_entropy_splits != 1: loss_value /= self._config.cross_entropy_splits @@ -265,24 +261,19 @@ def _logits_loss_forward_backward_partial( group=self._parallel_dim.group if self._vocab_parallel else None, sequence_parallel=self._sequence_parallel and self._vocab_parallel, ) - - sequence_dim = BlockDimNames.sequence_q_tp if self._sequence_parallel_logits else BlockDimNames.sequence_q - if LanguageModelKwargs.hidden_dims in kwargs: - batch_dim = kwargs[LanguageModelKwargs.hidden_dims][1 if kwargs[LanguageModelKwargs.sequence_first] else 0] - dims = ( - (sequence_dim, batch_dim, self._vocab_dim) - if kwargs[LanguageModelKwargs.sequence_first] - else (batch_dim, sequence_dim, self._vocab_dim) - ) - else: - dims = None - self._debug(logits, "logits", dims, kwargs, scale=self._config.logits_scale_factor) + self._debug( + logits, + f"logits{"" if self._config.cross_entropy_splits == 1 else f"_{split_index}"}", + (kwargs.get(LanguageModelKwargs.hidden_token_dim), self._hidden_dim), + kwargs, + scale=self._config.logits_scale_factor, + ) if return_logits: return logits, None losses, grad = {}, None - for loss in self._losses: + for loss in self.losses: # losses are returned unscaled but the grads are already scaled loss_value, grad = loss.forward_backward( logits, @@ -304,7 +295,7 @@ def get_loss_definitions(self, count: int = 1) -> list[LossDef]: count=count, dtype=DataType.float32, ) - for loss in self._losses + for loss in self.losses ), ] @@ -316,6 +307,6 @@ def _total_loss_name(self) -> str: return self._get_full_loss_name(LM_HEAD_LOSS_NAME) @property - def heads(self): + def heads(self) -> "list[LanguageModelHead]": # For compatibility with MTP. 
return [self] diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 537c7996d..d7a76dd83 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -53,6 +53,7 @@ def forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": + print("logits", logits.shape) return ( triton_entropy_loss_forward_backward if TritonConfig.enabled(logits.device, self._config.use_triton) diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 766a5ed54..8dc88c4a1 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -11,7 +11,7 @@ from fast_llm.utils import Assert -class LanguageModelLoss[ConfigType: LanguageModelLossConfig](Configurable[ConfigType]): +class LanguageModelLoss[ConfigType: LanguageModelLossConfig](Configurable[ConfigType], torch.nn.Module): def __init__( self, config: ConfigType, @@ -62,19 +62,17 @@ def _prepare_target( split_index: int = 0, *, multi_token_format: bool = False, + sequence_parallel: bool = True, ) -> torch.Tensor | None: # MTP shift if multi_token_format and self._prediction_heads > 1: - sequence_first: bool = kwargs[LanguageModelLossKwargs.sequence_first] - sequence_q_length = target.size(1 - sequence_first) + 1 - self._prediction_heads - target_slice = slice(self._prediction_distance, self._prediction_distance + sequence_q_length) - target = target[target_slice] if sequence_first else target[:, target_slice] - - # Flatten the batch and sequence dimensions. 
- target = target.flatten(0, 1) + sequence_q = kwargs[LanguageModelKwargs.sequence_q_dim].size + target = target.unflatten( + 0, (kwargs[LanguageModelKwargs.batch_dim].size, sequence_q + self._prediction_heads - 1) + )[:, self._prediction_distance : self._prediction_distance + sequence_q].flatten(0, 1) # Get the local chunk. - if self._sequence_parallel: + if sequence_parallel and self._sequence_parallel: target = split_op(target, self._parallel_dim.group, 0) # Get the chunk for the current split. @@ -104,7 +102,13 @@ def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): return None if loss_mask is None else self._prepare_target(loss_mask, kwargs, split_index) def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0): - return self._prepare_target(kwargs[f"{reference_model}_logits"], kwargs, split_index) + assert self._prediction_distance == 0 + Assert.incl( + logits_name := self.module_name.rsplit(".", 2)[0] + f".logits", + reference_hidden_states := kwargs[f"reference_{reference_model}_hidden_states"], + ) + # The logits are already sequence-parallel if needed, we don't want to split again. 
+ return self._prepare_target(reference_hidden_states[logits_name], kwargs, split_index, sequence_parallel=False) def loss_forward_backward( diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index ad3395a0f..5efe2d836 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -7,12 +7,12 @@ from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig -from fast_llm.layers.block.block import BlockBase from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig +from fast_llm.layers.language_model.head import LanguageModelHeadBase -class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](BlockBase[ConfigType]): +class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](LanguageModelHeadBase[ConfigType]): _config: ConfigType def __init__( diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index c7a2c1c59..5e721d424 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -12,7 +12,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import GatedDeltaNetConfig @@ -301,16 +300,12 @@ def _forward( - """ - sequence_first = kwargs[BlockKwargs.sequence_first] # in sequence parallel TP the input here is already scattered across sequence dimension # TODO: fuse soome of the reshapes into rearranges hidden_states = 
input_ projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) - if sequence_first: - projected_states_qkvz = projected_states_qkvz.transpose(0, 1) - projected_states_ba = projected_states_ba.transpose(0, 1) batch_size, sequence_length = projected_states_qkvz.shape[:2] @@ -371,8 +366,6 @@ def _forward( core_attn_out = self.norm(core_attn_out, z) core_attn_out = core_attn_out.reshape(z_shape_og) core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) - if sequence_first: - core_attn_out = core_attn_out.transpose(0, 1) output = self.out_proj(core_attn_out) return output diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 94cde7d5f..07ca3a997 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -11,7 +11,6 @@ from fast_llm.functional.config import ActivationType from fast_llm.layers.attention.config import MixerKwargs from fast_llm.layers.attention.preprocessing import preprocess_for_varlen -from fast_llm.layers.block.config import BlockKwargs from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.decoder.block import BlockWithBias from fast_llm.layers.ssm.config import KimiDeltaAttentionConfig @@ -229,7 +228,6 @@ def _forward( """ Same as in gdn, the idea is to always do forward pass in a packed way, whcih is required for varlen support. """ - sequence_first = kwargs[BlockKwargs.sequence_first] hidden_states = input_ # TODO: can be made more efficeint by rearranging hidden states directly and only once @@ -239,11 +237,6 @@ def _forward( k = self.k_proj(hidden_states) v = self.v_proj(hidden_states) - if sequence_first: - q = q.transpose(0, 1) - k = k.transpose(0, 1) - v = v.transpose(0, 1) - batch_size, sequence_length, _ = q.size() q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) k = rearrange(k, "b s ... 
-> (b s) ...").unsqueeze(0) @@ -257,8 +250,6 @@ def _forward( v = self._apply_conv(v, self.v_conv, seq_idx) g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - if sequence_first: - g_kernel = g_kernel.transpose(0, 1) g_kernel = self._reshape_heads(g_kernel) g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) @@ -268,8 +259,6 @@ def _forward( q = self._reshape_heads(q) k = self._reshape_heads(k) v = self._reshape_heads(v) - if sequence_first: - beta = beta.transpose(0, 1) beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md @@ -290,14 +279,10 @@ def _forward( g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim g_out = self._reshape_heads(g_out) - if sequence_first: - g_out = g_out.transpose(0, 1) attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) attn_out = self.norm(attn_out, g_out) attn_out = rearrange(attn_out, "b s h d -> b s (h d)") - if sequence_first: - attn_out = attn_out.transpose(0, 1) attn_out = self.o_proj(attn_out) return attn_out diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index 81b82d08e..fd6255e6c 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -166,16 +166,11 @@ def _forward( ) -> tuple[torch.Tensor, torch.Tensor | None]: assert _mamba_available - # inner_projection : (batch/local_sequence, local_sequence/batch, hidden) - # -> (batch/sequence, sequence/batch, local_inner_projection) - inner_projection = self.in_proj(input_) - dt = self.dt_proj(self.dt_in_proj(input_)) - # Standardize to (batch, sequence, local_inner_projection) - if kwargs[BlockKwargs.sequence_first]: - inner_projection = inner_projection.transpose(0, 1) - dt = dt.transpose(0, 1) - - sequence_length = inner_projection.size(1) + sequence_length = kwargs[BlockKwargs.sequence_q_dim].size + 
token_shape = (kwargs[BlockKwargs.batch_dim].size, kwargs[BlockKwargs.sequence_q_dim].size) + # inner_projection : (local_tokens, hidden) -> (batch, sequence, local_inner_projection) + inner_projection = self.in_proj(input_).unflatten(0, token_shape) + dt = self.dt_proj(self.dt_in_proj(input_)).unflatten(0, token_shape) z, x, b, c = torch.split( inner_projection, @@ -245,13 +240,10 @@ def _forward( # y: (batch, local_heads * state, sequence) -> (batch, sequence, local_heads * state) y = y.transpose(1, 2)[:, :sequence_length] - if kwargs[BlockKwargs.sequence_first]: - # TODO: Is contiguous needed? - y = y.transpose(0, 1).contiguous() - # (batch/sequence, sequence/batch, local_heads * state) - # -> (batch/local_sequence, local_sequence/batch, hidden) - out, bias = self.out_proj(y) - self._debug(out, None, kwargs.get(BlockKwargs.hidden_dims), kwargs) + # (batch, sequence, local_heads * state) + # -> (local_tokens, hidden) + out, bias = self.out_proj(y.flatten(0, 1)) + self._debug(out, None, (kwargs.get(BlockKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out, bias def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/vision/embeddings.py b/fast_llm/layers/vision/embeddings.py index 2076f72e5..0b0434f56 100644 --- a/fast_llm/layers/vision/embeddings.py +++ b/fast_llm/layers/vision/embeddings.py @@ -6,7 +6,6 @@ from fast_llm.engine.config_utils.initialization import init_normal_ from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames -from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.vision.config import PatchEmbeddingsConfig, VisionKwargs @@ -60,17 +59,13 @@ def forward( ) -> torch.Tensor: if isinstance(input_, TensorMeta): return 
TensorMeta.from_dims( - kwargs[VisionKwargs.hidden_dims], + (kwargs[VisionKwargs.hidden_token_dim], self._hidden_dim), tensor_name="Patch convolution output", dtype=self._residual_dtype, ) if self._sequence_parallel: input_ = split(input_, group=self._parallel_dim.group, dim=0) - out = ( - self.normalization(self.patch_embeddings(input_.flatten(1))) - .unsqueeze(int(kwargs[AttentionKwargs.sequence_first])) - .to(self._residual_dtype) - ) - self._debug(out, None, kwargs.get(VisionKwargs.hidden_dims), kwargs) + out = self.normalization(self.patch_embeddings(input_.flatten(1))).to(self._residual_dtype) + self._debug(out, None, (kwargs.get(VisionKwargs.hidden_token_dim), self._hidden_dim), kwargs) return out diff --git a/fast_llm/models/gpt/huggingface.py b/fast_llm/models/gpt/huggingface.py index a418c3fb5..387610a46 100644 --- a/fast_llm/models/gpt/huggingface.py +++ b/fast_llm/models/gpt/huggingface.py @@ -135,20 +135,20 @@ def _inner_forward( # The transformers will save the present keys and values to this list. kwargs[AttentionKwargs.presents] = [] - kwargs["global_logits"] = True - self._inference_runner.forward(input_, kwargs, iteration=iteration) # TODO: Make a proper way of returning the model output. - if kwargs[AttentionKwargs.sequence_first]: - logits = kwargs["logits"].transpose(0, 1) - else: - logits = kwargs["logits"] + # TODO: Handle MTP. 
+ logits_meta, logits = kwargs[AttentionKwargs.hidden_states]["head.logits"] + logits, _ = logits_meta.local_to_global(logits) + logits = logits.unflatten( + 0, (kwargs[AttentionKwargs.batch_dim].global_size, kwargs[AttentionKwargs.sequence_q_dim].global_size) + ) if output_hidden_states: hidden_states = { key: tensor if meta is None else meta.local_to_global(tensor)[0] - for key, (meta, tensor) in kwargs["hidden_states"].items() + for key, (meta, tensor) in kwargs[AttentionKwargs.hidden_states].items() } else: hidden_states = None diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index bd2932984..cabcdc489 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -1,3 +1,4 @@ +import functools import logging import re import typing @@ -72,35 +73,29 @@ def preprocess_meta( micro_sequence_length, self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) - hidden_sequence_q_dim = ( - TensorDim( - BlockDimNames.sequence_q_tp, - micro_sequence_length, - self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_sequence_data), + token_dim = TensorDim( + "token", + batch_dim.global_size * sequence_q_dim.global_size, + self._distributed_config.get_distributed_dim(DistributedDimNames.data), + ) + # The token dimension as appears in hidden states, i.e. with possible sequence-tensor-parallel split. 
+ hidden_token_dim = ( + ( + "token_tp", + token_dim.global_size, + self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), ) if self._distributed_config.sequence_tensor_parallel - else sequence_q_dim - ) - - need_sequence_first = hidden_sequence_q_dim.size != sequence_length - if self._config.sequence_first is None: - sequence_first = need_sequence_first - else: - sequence_first = self._config.sequence_first - assert not (need_sequence_first and not sequence_first) - - hidden_dims = ( - (hidden_sequence_q_dim, batch_dim, self._hidden_dim) - if sequence_first - else (batch_dim, hidden_sequence_q_dim, self._hidden_dim) + else token_dim ) common_kwargs = { LanguageModelKwargs.phase: phase, - AttentionKwargs.sequence_first: sequence_first, - AttentionKwargs.hidden_dims: hidden_dims, AttentionKwargs.sequence_length: sequence_length, + AttentionKwargs.batch_dim: batch_dim, AttentionKwargs.sequence_q_dim: sequence_q_dim, + AttentionKwargs.token_dim: token_dim, + AttentionKwargs.hidden_token_dim: hidden_token_dim, LanguageModelKwargs.mask_inputs: not truncate_documents, } @@ -122,7 +117,7 @@ def preprocess_meta( sequence_k_dim = TensorDim(BlockDimNames.sequence_k, sequence_k) tokens = TensorMeta.from_dims( - hidden_dims[:2], tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 + (token_dim,), tensor_name=f"tokens_{sequence_k_past}_to_{sequence_k-1}", dtype=torch.int64 ) kwargs = { @@ -131,16 +126,18 @@ def preprocess_meta( } if phase != PhaseType.inference: kwargs[LanguageModelKwargs.labels] = TensorMeta.from_dims( - hidden_dims[:2], tensor_name="labels", dtype=torch.int64 + (token_dim,), tensor_name="labels", dtype=torch.int64 ) reference_kwargs = {} for name, reference_preprocessed_meta in reference_preprocessed_metas.items(): reference_tokens, reference_kwargs_ = reference_preprocessed_meta[i] for key in ( - AttentionKwargs.sequence_first, AttentionKwargs.sequence_length, + AttentionKwargs.batch_dim, 
AttentionKwargs.sequence_q_dim, AttentionKwargs.sequence_k_dim, + AttentionKwargs.token_dim, + AttentionKwargs.hidden_token_dim, ): Assert.eq(reference_kwargs_[key], kwargs[key]) reference_kwargs[name] = reference_kwargs_ @@ -158,6 +155,7 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + extra_kwargs: dict[str, typing.Any] | None = None, ) -> list[tuple[torch.Tensor, dict]]: # TODO Move batch splitting elsewhere, align interface with LayerBase assert self._is_setup @@ -167,44 +165,18 @@ def preprocess_batch( if preprocessed_meta is None: preprocessed_meta = self.preprocess_meta(batch, phase) - distillation_models = self._config.decoder.get_reference_models() - # TODO: Support multiple distillation models? - assert len(distillation_models) <= 1 - reference_logits = [{} for _ in preprocessed_meta] + reference_preprocessed_batches = {} for name, reference_model in self._reference_models.items(): reference_preprocessed_meta = [ (tokens_meta, kwargs_meta["reference_models"][name]) for tokens_meta, kwargs_meta in preprocessed_meta ] - - # Set output_hidden_states in reference metadata before preprocessing if needed for distillation - if name in distillation_models: - reference_output_hidden_states = [r"decoder\.\d+\.mixer_output$"] - for _, ref_kwargs_meta in reference_preprocessed_meta: - ref_kwargs_meta[BlockKwargs.output_hidden_states] = [ - re.compile(pattern) for pattern in reference_output_hidden_states - ] - - reference_batch = reference_model.fast_llm_model.base_model.preprocess_batch( + reference_preprocessed_batches[name] = reference_model.fast_llm_model.base_model.preprocess_batch( batch, reference_preprocessed_meta, phase=PhaseType.inference, iteration=iteration, ) - # TODO: Do things work with >1? 
- Assert.eq(len(reference_batch), len(preprocessed_meta), 1) - for i, (reference_tokens, reference_kwargs) in enumerate(reference_batch): - reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - reference_logits[i][f"{name}_logits"] = reference_kwargs["logits"] - if BlockKwargs.hidden_states in reference_kwargs and reference_kwargs[BlockKwargs.hidden_states]: - # Extract activations from hidden_states dict (stored by _debug method) - # Format: {layer_name: (meta, tensor), ...} - activations = { - layer_name: tensor - for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() - } - reference_logits[i][f"{name}_activations"] = activations - preprocessed = [] presents = None for i, (_, kwargs_meta) in enumerate(preprocessed_meta): @@ -217,50 +189,66 @@ def preprocess_batch( pasts = presents presents = None if i == len(preprocessed_meta) - 1 else [] - # Create activation mask for activation distillation - # This mask should: - # - Be 0 on padding tokens (added at the end when documents aren't truncated) - # - Be 1 on image placeholder tokens (token value -100 but not padding) - # - Be 1 on all other valid tokens (ignores loss-masking-spans) - # - # Note: Padding is added as a separate document with all tokens = -100 - # We detect padding by checking if all tokens in a document segment are -100 - activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool) - - for sample_index, sample_lengths in enumerate(cropped_tokens.lengths): - # Iterate through documents in this sample - pos = 0 - for doc_length in sample_lengths: - # Check if this document is padding (all tokens are -100) - doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length] - is_padding_doc = torch.all(doc_tokens == -100).item() - - if is_padding_doc: - # This is a padding document, mask it out - activation_mask[sample_index, pos : pos + doc_length] = False - - pos += doc_length - kwargs: dict[str, typing.Any] = { 
**kwargs_meta, AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, BlockKwargs.iteration: iteration, AttentionKwargs.sequence_lengths: cropped_tokens.lengths, - BlockKwargs.activation_mask: activation_mask, AttentionKwargs.device: self._distributed.device, + BlockKwargs.output_hidden_states: [], BlockKwargs.hidden_states: {}, - **reference_logits[i], } + if extra_kwargs is not None: + Assert.empty(kwargs.keys() & extra_kwargs.keys()) + kwargs.update(extra_kwargs) + + # TODO: Simplify, check more carefully if needed. + if self._decoder_reference_models: + # Create activation mask for activation distillation + # This mask should: + # - Be 0 on padding tokens (added at the end when documents aren't truncated) + # - Be 1 on image placeholder tokens (token value -100 but not padding) + # - Be 1 on all other valid tokens (ignores loss-masking-spans) + # + # Note: Padding is added as a separate document with all tokens = -100 + # We detect padding by checking if all tokens in a document segment are -100 + activation_mask = torch.ones_like(cropped_tokens.tokens, dtype=torch.bool) + + for sample_index, sample_lengths in enumerate(cropped_tokens.lengths): + # Iterate through documents in this sample + pos = 0 + for doc_length in sample_lengths: + # Check if this document is padding (all tokens are -100) + doc_tokens = cropped_tokens.tokens[sample_index, pos : pos + doc_length] + is_padding_doc = torch.all(doc_tokens == -100).item() + + if is_padding_doc: + # This is a padding document, mask it out + activation_mask[sample_index, pos : pos + doc_length] = False + + pos += doc_length + + kwargs[BlockKwargs.activation_mask] = activation_mask.flatten() + + for name, reference_model in self._reference_models.items(): + reference_tokens, reference_kwargs = reference_preprocessed_batches[name][i] + if name in self._decoder_reference_models: + # TODO: Get the actual names + reference_kwargs[BlockKwargs.output_hidden_states].append( + 
re.compile(r"decoder\.\d+\.mixer_output$") + ) - # Add activation-distillation targets - assert len(distillation_models) <= 1 - for distillation_model in distillation_models: - teacher_key = f"{distillation_model}_activations" - if teacher_key in reference_logits[i]: - kwargs[BlockKwargs.activation_distillation_targets] = reference_logits[i].pop(teacher_key) + reference_model.forward(reference_tokens, reference_kwargs, iteration=iteration) - if phase != PhaseType.inference: + kwargs[f"reference_{name}_hidden_states"] = { + layer_name: tensor + for layer_name, (meta, tensor) in reference_kwargs[BlockKwargs.hidden_states].items() + } + + if phase == PhaseType.inference: + kwargs[BlockKwargs.output_hidden_states].append(re.compile(r"head\..*logits.*$")) + else: labels_begin = tokens_begin + 1 labels_end = tokens_end + self._config.head.max_prediction_distance labels = batch.tokens.crop(labels_begin, labels_end).tokens @@ -273,18 +261,13 @@ def preprocess_batch( loss_mask[sample_index, begin:end] = False labels = torch.where(loss_mask, labels, -100) + labels = labels.flatten(0, 1) + kwargs[LanguageModelKwargs.labels] = labels + if self._config.head.get_reference_models(): # loss masks only used for distillation currently # loss masks contain all three sources of masking: padding, user-defined spans, image placeholders kwargs[LanguageModelKwargs.loss_mask] = labels >= 0 - kwargs[LanguageModelKwargs.labels] = ( - labels.transpose(0, 1) if kwargs[AttentionKwargs.sequence_first] else labels - ).contiguous() - if LanguageModelKwargs.loss_mask in kwargs and kwargs[AttentionKwargs.sequence_first]: - kwargs[LanguageModelKwargs.loss_mask] = ( - kwargs[LanguageModelKwargs.loss_mask].transpose(0, 1).contiguous() - ) - if batch.chosen_spans is not None: kwargs[LanguageModelKwargs.chosen_spans] = batch.chosen_spans.crop(labels_begin, labels_end).ranges @@ -293,11 +276,7 @@ def preprocess_batch( labels_begin, labels_end ).ranges - tokens = ( - cropped_tokens.tokens.transpose(0, 1) - 
if kwargs[AttentionKwargs.sequence_first] - else cropped_tokens.tokens - ).contiguous() + tokens = cropped_tokens.tokens.flatten(0, 1) self.preprocess(kwargs) preprocessed.append((tokens, kwargs)) @@ -310,6 +289,19 @@ def get_tied_parameters(self) -> dict[str, tuple[ParameterMeta, tuple[int, ...]] output_weights.insert(0, self.embeddings.word_embeddings_weight) return {output_weights[0].tensor_name: output_weights} if len(output_weights) > 1 else {} + @functools.cached_property + def _decoder_reference_models(self) -> set[str]: + out = self._config.decoder.get_reference_models() + Assert.leq(out, self._reference_models.keys()) + Assert.leq(len(out), 1) + return out + + @functools.cached_property + def _head_reference_models(self) -> set[str]: + out = self._config.head.get_reference_models() + Assert.leq(out, self._reference_models.keys()) + return out + class GPTModel[ConfigType: GPTModelConfig](FastLLMModel[ConfigType]): # TODO: Can we drop class? diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index 890d5760e..e90bd4d89 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -9,7 +9,6 @@ from fast_llm.engine.distributed.config import DistributedDim, DistributedDimNames, PhaseType from fast_llm.engine.inference.runner import InferenceRunner from fast_llm.layers.attention.config import AttentionKwargs -from fast_llm.layers.block.config import BlockDimNames, BlockKwargs from fast_llm.layers.language_model.config import LanguageModelKwargs from fast_llm.layers.vision.config import VisionKwargs from fast_llm.layers.vision.vision_encoder import VisionMultiModalModel @@ -97,54 +96,48 @@ def preprocess_meta( # TODO: What about sequence data? 
batch_data_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.batch_data) - micro_sequence_length = tokens.global_shape.numel() - - batch_and_sequence_q_dim = PatchSequenceTensorDim( - BlockDimNames.sequence_q, - micro_sequence_length, + token_dim = PatchSequenceTensorDim( + "token", + kwargs[VisionKwargs.token_dim].global_size, self._distributed_config.get_distributed_dim(DistributedDimNames.data), batch_data_dim, ) - hidden_batch_and_sequence_q_dim = ( + hidden_token_dim = ( PatchSequenceTensorDim( - BlockDimNames.sequence_q_tp, - micro_sequence_length, + "token_tp", + token_dim.global_size, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), batch_data_dim, ) if self._distributed_config.sequence_tensor_parallel - else batch_and_sequence_q_dim + else token_dim ) # These are used by the model (preprocessing) and shouldn't see the batch-parallel dim. sequence_q_dim = TensorDim( - BlockDimNames.sequence_q, - micro_sequence_length, + "sequence_q", + token_dim.global_size, self._distributed_config.get_distributed_dim(DistributedDimNames.sequence_data), ) - sequence_k_dim = TensorDim(BlockDimNames.sequence_k, micro_sequence_length) + TensorDim("sequence_k", token_dim.global_size) image_patches = TensorMeta.from_dims( ( # We combine the batch and sequence dims to allow for variable sequence lengths. # Gives the same result, assuming we disable cross-image attention (TODO: Enforce) - batch_and_sequence_q_dim, + token_dim, # TODO: Relate to tensor dims in patch convolution. 
TensorDim("input_channels", self._config.vision_encoder.embeddings.input_channels), TensorDim("patch_height", self._config.vision_encoder.embeddings.patch_height), TensorDim("patch_width", self._config.vision_encoder.embeddings.patch_width), ) ) - # Use vision encoder's internal hidden dim (for embeddings/encoder), not the output dim (for adapter) - hidden_dims = ( - (hidden_batch_and_sequence_q_dim, scalar_dim, self.vision_encoder._vision_hidden_dim) - if (sequence_first := kwargs[LanguageModelKwargs.sequence_first]) - else (scalar_dim, hidden_batch_and_sequence_q_dim, self.vision_encoder._vision_hidden_dim) - ) kwargs[self._vision_encoder_namespace] = { - VisionKwargs.sequence_first: sequence_first, - VisionKwargs.sequence_k_dim: sequence_k_dim, - VisionKwargs.sequence_q_dim: sequence_q_dim, - VisionKwargs.hidden_dims: hidden_dims, + VisionKwargs.sequence_length: kwargs[VisionKwargs.sequence_length], + VisionKwargs.batch_dim: scalar_dim, + VisionKwargs.sequence_q_dim: token_dim, + VisionKwargs.sequence_k_dim: token_dim, + VisionKwargs.token_dim: token_dim, + VisionKwargs.hidden_token_dim: hidden_token_dim, } preprocessed_meta.append((image_patches, kwargs)) @@ -159,9 +152,10 @@ def preprocess_batch( phase: PhaseType, iteration: int, metrics: dict | None = None, + extra_kwargs: dict[str, typing.Any] | None = None, ) -> list[tuple[torch.Tensor, dict]]: preprocessed = super().preprocess_batch( - batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics + batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics, extra_kwargs=extra_kwargs ) # TODO: Support micro-sequences. assert len(preprocessed) == 1, "Micro-sequences not supported for MultiModalModel." 
@@ -194,22 +188,17 @@ def preprocess_batch( VisionKwargs.sequence_lengths: [cropped_image_patches.lengths + [pad_size]], VisionKwargs.sequence_length: sequence_length, VisionKwargs.device: self._distributed.device, - BlockKwargs.output_hidden_states: kwargs.get(BlockKwargs.output_hidden_states, []), - BlockKwargs.hidden_states: kwargs[BlockKwargs.hidden_states], + VisionKwargs.output_hidden_states: kwargs.get(VisionKwargs.output_hidden_states, []), + VisionKwargs.hidden_states: kwargs[VisionKwargs.hidden_states], } # We need to modify `local_unpadded_size` directly in `preprocessed_meta` since it's the one used by the engine. # Unsafe, but only needed for testing. # TODO: Doesn't work with gradient accumulation (only sees the last value). - hidden_batch_and_sequence_q_dim = kwargs[self._vision_encoder_namespace][VisionKwargs.hidden_dims][ - 0 if kwargs[self._vision_encoder_namespace][VisionKwargs.sequence_first] else 1 - ] - assert isinstance(hidden_batch_and_sequence_q_dim, PatchSequenceTensorDim) PatchSequenceTensorDim.local_unpadded_size = cropped_image_patches.patches.size(0) kwargs[LanguageModelKwargs.embedding_map] = ( - (cropped_image_patches.token_map, cropped_image_patches.sample_map) - if kwargs[LanguageModelKwargs.sequence_first] - else (cropped_image_patches.sample_map, cropped_image_patches.token_map) + cropped_image_patches.sample_map * kwargs[VisionKwargs.sequence_q_dim].size + + cropped_image_patches.token_map ) super().preprocess(kwargs) diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index ee3e0e2e1..b1a922099 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -6,6 +6,7 @@ import torch from fast_llm.engine.config_utils.data_type import DataType +from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.layers.attention.config import AttentionKwargs from fast_llm.layers.language_model.config import LM_HEAD_LOSS_NAME, LanguageModelKwargs from fast_llm.layers.language_model.head 
import LanguageModelHead @@ -28,7 +29,6 @@ class LMHeadTestConfig: logits_scale_factor: float = 1.0 compute_dtype: DataType = DataType.float32 full_precision_residual: bool = False - sequence_first: bool = False loss_masking: bool = False prediction_heads: int = 1 tied_embedding_weight: bool = False @@ -88,22 +88,15 @@ def get_config(self) -> GPTModelConfig: def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: device = "cuda" if torch.cuda.is_available() else "cpu" input_ = torch.randn( - ( - (SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE) - if self.sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH, HIDDEN_SIZE) - ), + (BATCH_SIZE * SEQUENCE_LENGTH, HIDDEN_SIZE), dtype=(torch.float32 if self.full_precision_residual else self.compute_dtype.torch), device=device, requires_grad=True, ) - label_shape = ( - (SEQUENCE_LENGTH + self.prediction_heads - 1, BATCH_SIZE) - if self.sequence_first - else (BATCH_SIZE, SEQUENCE_LENGTH + self.prediction_heads - 1) - ) + label_shape = (BATCH_SIZE * (SEQUENCE_LENGTH + self.prediction_heads - 1),) kwargs: dict[str, typing.Any] = { - AttentionKwargs.sequence_first: self.sequence_first, + AttentionKwargs.batch_dim: TensorDim("batch", BATCH_SIZE), + AttentionKwargs.sequence_q_dim: TensorDim("sequence_q", SEQUENCE_LENGTH), AttentionKwargs.grad_output: 1.0, } if self.loss_masking: @@ -122,11 +115,13 @@ def get_inputs(self) -> tuple[torch.Tensor, dict[str, typing.Any]]: if self.distillation_loss is not False: assert self.prediction_heads == 1 - kwargs[f"distillation_logits"] = torch.randn( - input_.shape[:-1] + (VOCAB_SIZE,), - dtype=input_.dtype, - device=device, - ) + kwargs[f"reference_distillation_hidden_states"] = { + "head.logits": torch.randn( + input_.shape[:-1] + (VOCAB_SIZE,), + dtype=input_.dtype, + device=device, + ) + } return input_, kwargs def get_reference_outputs( @@ -153,28 +148,25 @@ def get_reference_outputs( losses = {} if self.actual_label_loss is not False: - if self.sequence_first: - labels = 
kwargs[LanguageModelKwargs.labels][ - head._prediction_distance : head._prediction_distance + logits.size(0) - ] - else: - labels = kwargs[LanguageModelKwargs.labels][ - :, head._prediction_distance : head._prediction_distance + logits.size(1) + labels = ( + kwargs[LanguageModelKwargs.labels] + .view(BATCH_SIZE, (SEQUENCE_LENGTH + self.prediction_heads - 1))[ + :, head._prediction_distance : head._prediction_distance + SEQUENCE_LENGTH ] - label_loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), labels.flatten(), reduction="none" - ).mean() + .flatten() + ) + label_loss = torch.nn.functional.cross_entropy(logits, labels, reduction="none").mean() losses["label"] = label_loss.detach() total_loss = total_loss + float(self.actual_label_loss) * label_loss if self.distillation_loss is not False: distillation_loss = torch.nn.functional.cross_entropy( - logits.flatten(0, -2), - torch.softmax(kwargs[f"distillation_logits"].flatten(0, -2), -1), + logits, + torch.softmax(kwargs[f"reference_distillation_hidden_states"]["head.logits"], -1), reduction="none", ) if LanguageModelKwargs.loss_mask in kwargs: - distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask].flatten() + distillation_loss = distillation_loss * kwargs[LanguageModelKwargs.loss_mask] distillation_loss = distillation_loss.mean() losses["distillation"] = distillation_loss.detach() total_loss = total_loss + float(self.distillation_loss) * distillation_loss @@ -220,7 +212,6 @@ def _add_configs(base_name: str, **kwargs): _add_configs("default") _add_configs("bfloat16", compute_dtype=DataType.bfloat16) _add_configs("full_precision_residual", full_precision_residual=True) -_add_configs("sequence_first", sequence_first=True) _add_configs("logit_scaling", logits_scale_factor=5.0) _add_configs("tied_embedding_weight", tied_embedding_weight=True) _add_configs("multi_token_prediction", prediction_heads=2) @@ -240,7 +231,7 @@ def _add_configs(base_name: str, **kwargs): for 
_lm_head_test_config in _lm_head_test_configs ], ) -def test_lm_head(test_config): +def test_lm_head(test_config: LMHeadTestConfig): model_config = test_config.get_config() model, distributed = get_base_model(model_config) input_, kwargs = test_config.get_inputs() diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index c12fe52e9..d096b4af3 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -69,9 +69,7 @@ def _compare_mixers( sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] fast_kwargs = { BlockKwargs.device: distributed.device, - BlockKwargs.sequence_first: False, BlockKwargs.sequence_lengths: sequence_lengths, - BlockKwargs.hidden_dims: (HIDDEN_SIZE,), BlockKwargs.sequence_q_dim: TensorDim("", SEQ_LEN), BlockKwargs.sequence_k_dim: TensorDim("", SEQ_LEN), } diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py index bc538f9a0..d31cffa50 100644 --- a/tests/layers/test_varlen.py +++ b/tests/layers/test_varlen.py @@ -45,8 +45,7 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): """ Check that Gated Delta Net forward/backward match with and without packing. 
""" - hidden_size = 32 - hidden_dim = TensorDim("hidden", hidden_size) + hidden_dim = TensorDim("hidden", hidden_size := 32) distributed = Distributed( distributed_config := DistributedConfig(compute_dtype=DataType.float16, use_cuda=torch.cuda.is_available()) ) @@ -68,20 +67,19 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): kwargs = { BlockKwargs.device: distributed.device, - BlockKwargs.sequence_first: False, - BlockKwargs.hidden_dims: (hidden_dim,), } kwargs_packed = { **kwargs, BlockKwargs.sequence_lengths: sequence_lengths, BlockKwargs.sequence_length: seq_len, + BlockKwargs.batch_dim: TensorDim("", batch_size), BlockKwargs.sequence_q_dim: TensorDim("", seq_len), BlockKwargs.sequence_k_dim: TensorDim("", seq_len), } mixer.preprocess(kwargs_packed) - out_packed, context = stage.forward(hidden_states, kwargs_packed) + out_packed, context = stage.forward(hidden_states.flatten(0, 1), kwargs_packed) stage.backward(torch.ones_like(out_packed), context) names, parameters = zip(*list(mixer.named_parameters())) @@ -97,14 +95,15 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig): **kwargs, BlockKwargs.sequence_lengths: [[seq_len_]], BlockKwargs.sequence_length: seq_len_, + BlockKwargs.batch_dim: TensorDim("", 1), BlockKwargs.sequence_q_dim: TensorDim("", seq_len_), BlockKwargs.sequence_k_dim: TensorDim("", seq_len_), } mixer.preprocess(kwargs_seq) - out, context = stage.forward(seq.unsqueeze(0), kwargs_seq) + out, context = stage.forward(seq, kwargs_seq) stage.backward(torch.ones_like(out), context) out_refs.append(out) - out_ref = torch.cat(out_refs, dim=1).view_as(out_packed) + out_ref = torch.cat(out_refs, dim=0).view_as(out_packed) Assert.rms_close_relative(out_packed, out_ref, 1e-3, 1e-4) diff --git a/tests/test_loss_mask.py b/tests/test_loss_mask.py index 8c131dfa7..cdf2295e0 100644 --- a/tests/test_loss_mask.py +++ b/tests/test_loss_mask.py @@ -220,13 +220,7 @@ def test_all_padding_sample(self): labels = 
kwargs[LanguageModelKwargs.labels] # Get labels for sample 1 (all should be -100) - # Handle sequence_first dimension ordering - if labels.shape[0] > labels.shape[1]: - # sequence_first=True: shape is (seq, batch) - sample1_labels = labels[:, 1] - else: - # sequence_first=False: shape is (batch, seq) - sample1_labels = labels[1, :] + sample1_labels = labels[8:] assert torch.all(sample1_labels == -100), f"All labels in padding sample should be -100, got {sample1_labels}" diff --git a/tests/utils/distributed_configs.py b/tests/utils/distributed_configs.py index f08e9a488..bd5a92720 100644 --- a/tests/utils/distributed_configs.py +++ b/tests/utils/distributed_configs.py @@ -123,14 +123,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon num_gpus=1, compare_config=_fp16_compare, ), - # Sequence-first baseline - DistributedTestingConfig( - name="sf", - compare="simple", - config_args=["model.base_model.sequence_first=True"], - num_gpus=1, - compare_config=_compare_layer_mismatch, - ), # Cross-entropy splits. DistributedTestingConfig( name="ce4", @@ -171,14 +163,6 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon num_gpus=1, compare_config=_compare_layer_match, ), - # Sequence-first gradient accumulation baseline. 
- DistributedTestingConfig( - name="df4_sf", - compare="simple", - config_args=["batch.depth_first_micro_batches=4", "model.base_model.sequence_first=True"], - num_gpus=1, - compare_config=_compare_layer_mismatch, - ), ] SINGLE_GPU_TESTING_CONFIGS = {config.name: config for config in _SINGLE_GPU_TESTING_CONFIGS} @@ -221,7 +205,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Sequence-data-parallel DistributedTestingConfig( name="sdp2", - compare="sf", + compare="simple", config_args=["model.distributed.sequence_data_parallel=2"], num_gpus=2, compare_config=_compare_layer_match, @@ -238,7 +222,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Simple sequence-tensor-parallel DistributedTestingConfig( name="stp2", - compare="sf", + compare="simple", config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", @@ -260,7 +244,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Cross-entropy splits DistributedTestingConfig( name="stp2_ce4", - compare="sf", + compare="simple", config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", @@ -274,7 +258,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Simple DistributedTestingConfig( name="dp2_stp2", - compare="sf", + compare="simple", config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", @@ -285,7 +269,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Breadth-first micro-batches DistributedTestingConfig( name="sdp2_stp2_bf4", - compare="df4_sf", + compare="df4", config_args=[ "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", @@ -298,7 +282,7 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon # Sequence-data-parallel 
DistributedTestingConfig( name="sdp2_stp2", - compare="sf", + compare="simple", config_args=[ "model.distributed.sequence_data_parallel=2", "model.distributed.tensor_parallel=2", @@ -358,10 +342,10 @@ def get_config(relative: float = 0, absolute: float = 0, **kwargs) -> CompareCon compare_config=_compare_layer_match, ), # ===== 2d configs (Tensor + Pipeline) - # Simple [sf, mb] + # Simple [mb] DistributedTestingConfig( name="stp2_pp2s1_bf4", - compare="df4_sf", + compare="df4", config_args=[ "model.distributed.tensor_parallel=2", "model.distributed.sequence_tensor_parallel=True", diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index a4f28d14c..7b41c1f50 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -314,6 +314,7 @@ def update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=1.5, ) update_and_add_testing_config( @@ -333,6 +334,7 @@ def update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, + compare_factor=1.5, ) update_and_add_testing_config( @@ -360,6 +362,7 @@ def update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.unimportant, ModelTestingGroup.distributed: ModelTestingGroupAction.unimportant, }, + compare_factor=1.0, ) del MODEL_CONFIGS["starcoder_2"].config_dict["model"]["base_model"]["embeddings"]["num_position_embeddings"] @@ -394,6 +397,7 @@ def update_and_add_testing_config( ModelTestingGroup.megatron: ModelTestingGroupAction.normal, ModelTestingGroup.distributed: ModelTestingGroupAction.normal, }, + compare_factor=1.0, ) update_and_add_testing_config( From e0d0d7da27300eb7c9a70745c0ec5fdc7567273c Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Feb 2026 13:25:40 -0500 Subject: [PATCH 18/22] cleaanp --- 
fast_llm/engine/schedule/runner.py | 10 ---------- fast_llm/layers/attention/rotary/rotary.py | 1 - fast_llm/layers/language_model/loss/entropy_loss.py | 1 - 3 files changed, 12 deletions(-) diff --git a/fast_llm/engine/schedule/runner.py b/fast_llm/engine/schedule/runner.py index 2d7e02f77..4a6f3b3cb 100644 --- a/fast_llm/engine/schedule/runner.py +++ b/fast_llm/engine/schedule/runner.py @@ -19,7 +19,6 @@ from fast_llm.engine.optimizer.optimizer import Optimizer from fast_llm.engine.schedule.config import EventType, MockEvent, MockStream, ScheduleConfig, StepType, StreamType from fast_llm.engine.schedule.schedule import Schedule, Step -from fast_llm.layers.block.config import BlockKwargs from fast_llm.logging import log_memory_usage from fast_llm.utils import Assert @@ -406,15 +405,6 @@ def _recv(self, context: BatchContext, step: Step) -> None: self._record_event(context, EventType.compute_wait_pipe, step) def _forward(self, context: BatchContext, step: Step) -> None: - print( - "IASINBUI", - step, - ( - context.batch[step.data_index].get(BlockKwargs.grad_output) - if step.data_index in context.batch - else "PPPPP" - ), - ) output, grad_context = self._stages[step.stage].forward( self._get_forward_input(context, step), context.batch[step.data_index], diff --git a/fast_llm/layers/attention/rotary/rotary.py b/fast_llm/layers/attention/rotary/rotary.py index e24d85e36..307256a72 100644 --- a/fast_llm/layers/attention/rotary/rotary.py +++ b/fast_llm/layers/attention/rotary/rotary.py @@ -235,7 +235,6 @@ def forward( self, query: torch.Tensor, key: torch.Tensor, kwargs: dict[str, typing.Any] ) -> tuple[torch.Tensor, torch.Tensor]: rotary_fn = triton_rotary_autograd_ if TritonConfig.enabled(query.device) else rotary_embeddings_real - print("AAAAA", query.shape, kwargs[AttentionKwargs.rotary_freq_q].shape) query = rotary_fn(query, kwargs[AttentionKwargs.rotary_freq_q]) key = rotary_fn(key, kwargs[AttentionKwargs.rotary_freq_k]) return query, key diff --git 
a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index d7a76dd83..537c7996d 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -53,7 +53,6 @@ def forward_backward( split_index: int = 0, grad_logits: torch.Tensor | None = None, ) -> "tuple[torch.Tensor, torch.Tensor | None]": - print("logits", logits.shape) return ( triton_entropy_loss_forward_backward if TritonConfig.enabled(logits.device, self._config.use_triton) From 99e6400e8acf8e8e30493e0583e12ffe2c62b339 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Feb 2026 14:22:53 -0500 Subject: [PATCH 19/22] fix --- fast_llm/layers/attention/attention.py | 3 --- fast_llm/layers/decoder/block.py | 1 - 2 files changed, 4 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index be8b31f39..859bafea2 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ -330,7 +330,6 @@ def _forward( if (past_key_values := kwargs.get(AttentionKwargs.past_key_values)) is not None: # Clear the lists so tensors can be de-allocated - # TODO: ===== Check ===== key_value = torch.cat((past_key_values.pop(0), key_value), dim=1) if (presents := kwargs.get(AttentionKwargs.presents)) is not None: @@ -340,11 +339,9 @@ def _forward( # Manually add the gradients from later micro-sequences. 
key_value = AttachGrad.apply(key_value, present) - # TODO: ===== Check ===== key_value = key_value[:, : kwargs[AttentionKwargs.sequence_k_dim].size] key, value = key_value.split(self._local_head_groups * self._config.head_size, dim=-1) - # TODO: ===== Expand batch seq dim ===== query = query.view(*query.shape[:2], self._local_heads, self._config.head_size) key = key.view(*key.shape[:2], self._local_head_groups, self._config.head_size) value = value.view(*value.shape[:2], self._local_head_groups, self._config.head_size) diff --git a/fast_llm/layers/decoder/block.py b/fast_llm/layers/decoder/block.py index bb9df3fb9..dd19c1086 100644 --- a/fast_llm/layers/decoder/block.py +++ b/fast_llm/layers/decoder/block.py @@ -179,7 +179,6 @@ def _activation_distillation_loss(self, hidden_states, kwargs, losses, metrics): scaled_activation_loss = self._config.distillation_loss_weight * loss # Backward hook - print(kwargs[BlockKwargs.grad_output]) hidden_states = AuxiliaryLoss.apply(hidden_states, scaled_activation_loss, kwargs.get(BlockKwargs.grad_output)) # Logging From 15c0e4325fe64634e9813a901a6ff8b25591fb4a Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Feb 2026 15:50:27 -0500 Subject: [PATCH 20/22] Simplify MTP --- fast_llm/layers/block/config.py | 14 ++ fast_llm/layers/block/sequence.py | 16 ++- fast_llm/layers/language_model/config.py | 131 ++++-------------- fast_llm/layers/language_model/head.py | 35 ++--- .../layers/language_model/language_model.py | 20 ++- .../language_model/multi_token_prediction.py | 46 +++--- fast_llm/models/gpt/config.py | 4 +- fast_llm/models/gpt/conversion/llama.py | 7 +- fast_llm/models/gpt/conversion/mtp_llama.py | 41 ++---- fast_llm/models/gpt/model.py | 2 +- fast_llm/models/gpt/trainer.py | 2 +- .../models/multimodal/conversion/llava.py | 7 +- tests/layers/test_lm_head.py | 14 +- tests/utils/model_configs.py | 7 +- 14 files changed, 137 insertions(+), 209 deletions(-) diff --git a/fast_llm/layers/block/config.py 
b/fast_llm/layers/block/config.py index a1b600445..4f8595250 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -84,6 +84,7 @@ def get_layer( *, lr_scale: float | None, peft: PeftConfig | None, + **kwargs, ) -> "BlockBase": return self.layer_class( self, @@ -91,6 +92,7 @@ def get_layer( hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, + **kwargs, ) def get_reference_models(self) -> set[str]: @@ -106,6 +108,10 @@ def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typi return FixedBlockSequenceConfig._from_dict(default, strict) return super()._from_dict(default, strict=strict) + @property + def last_block_config(self) -> BlockConfig: + raise NotImplementedError() + @config_class(dynamic_type={BlockSequenceConfig: "fixed"}) class FixedBlockSequenceConfig(BlockSequenceConfig): @@ -130,6 +136,10 @@ def layer_class(self) -> "type[FixedBlockSequence]": def get_reference_models(self) -> set[str]: return self.block.get_reference_models() + @property + def last_block_config(self) -> BlockConfig: + return self.block + @config_class(dynamic_type={BlockSequenceConfig: "pattern"}) class PatternBlockSequenceConfig(BlockSequenceConfig): @@ -161,6 +171,10 @@ def _validate(self): super()._validate() + @property + def last_block_config(self) -> BlockConfig: + return self.blocks[self.expanded_pattern[-1]] + @property def layer_class(self) -> "type[PatternBlockSequence]": from fast_llm.layers.block.sequence import PatternBlockSequence diff --git a/fast_llm/layers/block/sequence.py b/fast_llm/layers/block/sequence.py index 54a5b3471..2e7425343 100644 --- a/fast_llm/layers/block/sequence.py +++ b/fast_llm/layers/block/sequence.py @@ -24,6 +24,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_last_layer_input: bool = False, ): super().__init__( config, @@ -40,8 +41,13 @@ def __init__( hidden_dim, lr_scale=self._lr_scale, peft=self._peft, + **( 
+ {"return_input": True} + if return_last_layer_input and block_index == self._config.num_blocks - 1 + else {} + ), ) - for _ in range(self._config.num_blocks) + for block_index in range(self._config.num_blocks) ] ) @@ -75,6 +81,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + return_last_layer_input: bool = False, ): super().__init__( config, @@ -90,8 +97,13 @@ def __init__( hidden_dim, lr_scale=self._lr_scale, peft=self._peft, + **( + {"return_input": True} + if return_last_layer_input and block_index == self._config.num_blocks - 1 + else {} + ), ) - for name in self._config.expanded_pattern + for block_index, name in enumerate(self._config.expanded_pattern) ] ) diff --git a/fast_llm/layers/language_model/config.py b/fast_llm/layers/language_model/config.py index e3446bba6..0e54e7583 100644 --- a/fast_llm/layers/language_model/config.py +++ b/fast_llm/layers/language_model/config.py @@ -1,4 +1,3 @@ -import abc import typing from fast_llm.config import Field, FieldHint, check_field, config_class, skip_valid_if_none @@ -14,7 +13,7 @@ if typing.TYPE_CHECKING: from fast_llm.layers.language_model.embedding import LanguageModelEmbedding - from fast_llm.layers.language_model.head import LanguageModelHead, LanguageModelHeadBase + from fast_llm.layers.language_model.head import LanguageModelHead from fast_llm.layers.language_model.language_model import LanguageModel from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction @@ -95,41 +94,8 @@ def layer_class(self) -> "type[LanguageModelEmbedding]": return LanguageModelEmbedding -@config_class(registry=True) -class LanguageModelHeadBaseConfig(BlockConfig): - @classmethod - def _from_dict(cls, default: dict[str, typing.Any], strict: bool = True) -> typing.Self: - if cls is LanguageModelHeadBaseConfig and cls.get_subclass(default.get("type")) is None: - # Default subclass. 
- return LanguageModelHeadConfig._from_dict(default, strict) - return super()._from_dict(default, strict=strict) - - def get_layer( - self, - distributed_config: DistributedConfig, - embeddings_config: LanguageModelEmbeddingsConfig, - *, - hidden_dim: TensorDim, - lr_scale: float | None, - peft: PeftConfig | None, - ) -> "LanguageModelHeadBase": - return self.layer_class( - self, - distributed_config, - embeddings_config, - hidden_dim=hidden_dim, - lr_scale=combine_lr_scales(lr_scale, self.lr_scale), - peft=peft, - ) - - @property - @abc.abstractmethod - def max_prediction_distance(self) -> int: - pass - - -@config_class(dynamic_type={LanguageModelHeadBaseConfig: "language_model_head"}) -class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): +@config_class() +class LanguageModelHeadConfig(BlockConfig): _abstract = False normalization: NormalizationConfig = Field( desc="Configuration for the final normalization layer.", @@ -160,6 +126,18 @@ class LanguageModelHeadConfig(LanguageModelHeadBaseConfig): hint=FieldHint.feature, valid=check_field(Assert.geq, 0), ) + prediction_heads: int = Field( + default=1, + desc="Prediction heads.", + hint=FieldHint.architecture, + valid=check_field(Assert.gt, 0), + ) + prediction_loss_coefficient: list[float] | None = Field( + default=None, + desc="Loss coefficient for each prediction head.", + doc="If not provided, all heads are equally weighted.", + hint=FieldHint.feature, + ) def get_layer( self, @@ -169,85 +147,36 @@ def get_layer( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, - prediction_distance: int = 0, - prediction_heads: int = 1, - loss_coefficient: float = 1.0, - ): - return self.layer_class( + block_config: DecoderBlockConfig | None = None, + ) -> "tuple[LanguageModelHead, MultiTokenPrediction]": + from fast_llm.layers.language_model.head import LanguageModelHead + from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction + + return LanguageModelHead( + self, + 
distributed_config, + embeddings_config, + hidden_dim=hidden_dim, + lr_scale=combine_lr_scales(lr_scale, self.lr_scale), + peft=peft, + ), MultiTokenPrediction( self, distributed_config, embeddings_config, hidden_dim=hidden_dim, lr_scale=combine_lr_scales(lr_scale, self.lr_scale), peft=peft, - prediction_distance=prediction_distance, - prediction_heads=prediction_heads, - loss_coefficient=loss_coefficient, + block_config=block_config, ) - @property - def layer_class(self) -> "type[LanguageModelHead]": - from fast_llm.layers.language_model.head import LanguageModelHead - - return LanguageModelHead - def _validate(self) -> None: super()._validate() assert LM_HEAD_LOSS_NAME not in self.losses - @property - def max_prediction_distance(self) -> int: - return 1 - def get_reference_models(self) -> set[str]: return {reference_model for loss in self.losses.values() for reference_model in loss.get_reference_models()} -@config_class(dynamic_type={LanguageModelHeadBaseConfig: "multi_token_prediction"}) -class MultiTokenPredictionConfig(LanguageModelHeadBaseConfig): - _abstract = False - # Needs to be `DecoderBlockConfig` for the `return_input` interface. - # TODO: Make a generic wrapper for returning input instead? - block: DecoderBlockConfig = Field( - desc="Configuration for the decoder block before each head.", - hint=FieldHint.architecture, - ) - # TODO: Generalize? 
(needs the extra initialization arguments) - head: LanguageModelHeadConfig = Field( - desc="Configuration for the multi-token-prediction heads.", - hint=FieldHint.architecture, - ) - prediction_heads: int = Field( - default=1, - desc="Prediction heads.", - hint=FieldHint.architecture, - valid=check_field(Assert.gt, 0), - ) - prediction_loss_coefficient: list[float] | None = Field( - default=None, - desc="Loss coefficient for each prediction head.", - doc="If not provided, all heads are equally weighted.", - hint=FieldHint.feature, - ) - - def _validate(self) -> None: - super()._validate() - if isinstance(self.prediction_loss_coefficient, list): - Assert.eq(len(self.prediction_loss_coefficient), self.prediction_heads) - for coeff in self.prediction_loss_coefficient: - Assert.geq(coeff, 0) - - @property - def layer_class(self) -> "type[MultiTokenPrediction]": - from fast_llm.layers.language_model.multi_token_prediction import MultiTokenPrediction - - return MultiTokenPrediction - - @property - def max_prediction_distance(self) -> int: - return self.prediction_heads - - @config_class() class LanguageModelConfig(BlockConfig): decoder: BlockSequenceConfig = Field( @@ -258,7 +187,7 @@ class LanguageModelConfig(BlockConfig): hint=FieldHint.architecture, desc="Configuration for the language model embeddings.", ) - head: LanguageModelHeadBaseConfig = Field( + head: LanguageModelHeadConfig = Field( hint=FieldHint.architecture, desc="Configuration for the language model head(s)." 
) tied_embedding_weight: bool = Field( diff --git a/fast_llm/layers/language_model/head.py b/fast_llm/layers/language_model/head.py index c5bf9ff9b..85b9bde1d 100644 --- a/fast_llm/layers/language_model/head.py +++ b/fast_llm/layers/language_model/head.py @@ -1,4 +1,3 @@ -import abc import functools import logging import typing @@ -14,12 +13,11 @@ from fast_llm.engine.distributed.config import DistributedConfig, DistributedDimNames from fast_llm.functional.autograd import AuxiliaryLoss, grad_is_context, wrap_forward_backward from fast_llm.functional.linear import output_parallel_linear_backward, output_parallel_linear_forward -from fast_llm.layers.block.block import Block, BlockBase +from fast_llm.layers.block.block import Block from fast_llm.layers.common.peft.config import PeftConfig from fast_llm.layers.language_model.config import ( LM_HEAD_LOSS_NAME, LanguageModelEmbeddingsConfig, - LanguageModelHeadBaseConfig, LanguageModelHeadConfig, LanguageModelKwargs, ) @@ -32,15 +30,7 @@ OUTPUT_WEIGHTS = "output_weights" -class LanguageModelHeadBase[ConfigType: LanguageModelHeadBaseConfig](BlockBase[ConfigType]): - heads: "list[LanguageModelHead]" - - @abc.abstractmethod - def get_output_weights(self) -> list[torch.Tensor]: - pass - - -class LanguageModelHead[ConfigType: LanguageModelHeadConfig](LanguageModelHeadBase[ConfigType], Block): +class LanguageModelHead[ConfigType: LanguageModelHeadConfig](Block[ConfigType]): """ A language model head (GPT), which combines the final layer norm, logits and cross-entropy (if applicable). TODO: Cleanup (dynamic type? composition?) 
@@ -58,7 +48,6 @@ def __init__( lr_scale: float | None, peft: PeftConfig | None, prediction_distance: int = 0, - prediction_heads: int = 1, loss_coefficient: float = 1.0, ): super().__init__( @@ -68,11 +57,9 @@ def __init__( lr_scale=lr_scale, peft=peft, ) - Assert.in_range(prediction_distance, 0, prediction_heads) + Assert.in_range(prediction_distance, 0, self._config.prediction_heads) self._prediction_distance = prediction_distance - self._prediction_heads = prediction_heads - self._loss_coefficient = loss_coefficient - self._is_last_head = self._prediction_distance == self._prediction_heads - 1 + self._is_last_head = self._prediction_distance == self._config.prediction_heads - 1 self._vocab_parallel = self._distributed_config.tensor_parallel > 1 and embeddings_config.vocab_parallel self._parallel_dim = self._distributed_config.get_distributed_dim(DistributedDimNames.tensor) @@ -99,17 +86,22 @@ def __init__( loss_configs = ( self._config.losses if self._config.losses else {"cross_entropy": LanguageModelLabelEntropyLossConfig()} ) + loss_coefficient = ( + 1.0 + if self._config.prediction_loss_coefficient is None + else self._config.prediction_loss_coefficient[self._prediction_distance] + ) self.losses = torch.nn.ModuleList( [ loss_config.get_layer( distributed_config, self._get_full_loss_name(name), self._prediction_distance, - self._prediction_heads, + self._config.prediction_heads, self._vocab_parallel, self._config.cross_entropy_splits, self._config.logits_scale_factor, - self._loss_coefficient, + loss_coefficient, ) for name, loss_config in loss_configs.items() ] @@ -305,8 +297,3 @@ def _get_full_loss_name(self, name) -> str: @functools.cached_property def _total_loss_name(self) -> str: return self._get_full_loss_name(LM_HEAD_LOSS_NAME) - - @property - def heads(self) -> "list[LanguageModelHead]": - # For compatibility with MTP. 
- return [self] diff --git a/fast_llm/layers/language_model/language_model.py b/fast_llm/layers/language_model/language_model.py index 385bab7ef..32e2ccbf9 100644 --- a/fast_llm/layers/language_model/language_model.py +++ b/fast_llm/layers/language_model/language_model.py @@ -44,28 +44,42 @@ def __init__( self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft, + **({"return_last_layer_input": True} if self._config.head.prediction_heads > 1 else {}), ) - self.head = self._config.head.get_layer( + self.head, self.multi_token_prediction = self._config.head.get_layer( distributed_config, self._config.embeddings, hidden_dim=self._hidden_dim, lr_scale=self._lr_scale, peft=self._peft, + **( + {"block_config": self._config.decoder.last_block_config} + if self._config.head.prediction_heads > 1 + else {} + ), ) def get_layers(self) -> list[Layer]: - return self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() + layers = self.embeddings.get_layers() + self.decoder.get_layers() + self.head.get_layers() + if self.multi_token_prediction is not None: + layers += self.multi_token_prediction.get_layers() + return layers def preprocess(self, kwargs: dict[str, typing.Any]) -> None: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? self.embeddings.preprocess(kwargs) self.decoder.preprocess(kwargs) self.head.preprocess(kwargs) + if self.multi_token_prediction is not None: + self.multi_token_prediction.preprocess(kwargs) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: # Needed because the base class uses `get_layers` which may bypass the decoder and head. TODO: Avoidable? 
- return ( + losses = ( self.embeddings.get_loss_definitions(count) + self.decoder.get_loss_definitions(count) + self.head.get_loss_definitions(count) ) + if self.multi_token_prediction is not None: + losses += self.multi_token_prediction.get_loss_definitions(count) + return losses diff --git a/fast_llm/layers/language_model/multi_token_prediction.py b/fast_llm/layers/language_model/multi_token_prediction.py index 5efe2d836..d7665cf00 100644 --- a/fast_llm/layers/language_model/multi_token_prediction.py +++ b/fast_llm/layers/language_model/multi_token_prediction.py @@ -7,12 +7,14 @@ from fast_llm.engine.base_model.config import LossDef from fast_llm.engine.config_utils.tensor_dim import TensorDim from fast_llm.engine.distributed.config import DistributedConfig +from fast_llm.layers.block.block import BlockBase from fast_llm.layers.common.peft.config import PeftConfig -from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, MultiTokenPredictionConfig -from fast_llm.layers.language_model.head import LanguageModelHeadBase +from fast_llm.layers.decoder.config import DecoderBlockConfig +from fast_llm.layers.language_model.config import LanguageModelEmbeddingsConfig, LanguageModelHeadConfig +from fast_llm.layers.language_model.head import LanguageModelHead -class MultiTokenPrediction[ConfigType: MultiTokenPredictionConfig](LanguageModelHeadBase[ConfigType]): +class MultiTokenPrediction[ConfigType: LanguageModelHeadConfig](BlockBase[ConfigType]): _config: ConfigType def __init__( @@ -24,6 +26,7 @@ def __init__( hidden_dim: TensorDim, lr_scale: float | None, peft: PeftConfig | None, + block_config: DecoderBlockConfig | None = None, ): super().__init__( config, @@ -32,9 +35,12 @@ def __init__( lr_scale=lr_scale, peft=peft, ) + self._enabled = self._config.prediction_heads > 1 + if self._enabled: + assert block_config is not None self.blocks = torch.nn.ModuleList( [ - self._config.block.get_layer( + block_config.get_layer( self._distributed_config, 
self._hidden_dim, lr_scale=self._lr_scale, @@ -43,26 +49,21 @@ def __init__( # The previous blocks return a stack of shared_hidden and transformer_output. return_input=index < self._config.prediction_heads - 1, ) - for index in range(self._config.prediction_heads) + for index in range(1, self._config.prediction_heads) ] ) self.heads = torch.nn.ModuleList( [ - self._config.head.get_layer( + LanguageModelHead( + self._config, distributed_config, embeddings_config, hidden_dim=hidden_dim, lr_scale=lr_scale, peft=peft, prediction_distance=index, - prediction_heads=self._config.prediction_heads, - loss_coefficient=( - 1.0 - if self._config.prediction_loss_coefficient is None - else self._config.prediction_loss_coefficient[index] - ), ) - for index in range(self._config.prediction_heads) + for index in range(1, self._config.prediction_heads) ] ) @@ -70,8 +71,11 @@ def __init__( def _layers_with_namespace(self) -> list[Layer]: # Wrap all blocks in a namespace using the unique module name of the first one. # This needs to be in a property because `module_name` is set after `__init__`. 
- namespace = self.blocks[0].module_name - return [LayerWithNamespace(sublayer, namespace) for layer in self.blocks for sublayer in layer.get_layers()] + return [ + LayerWithNamespace(sublayer, self.blocks[0].module_name) + for layer in self.blocks + for sublayer in layer.get_layers() + ] def get_layers(self) -> list[Layer]: return [ @@ -84,9 +88,13 @@ def get_output_weights(self) -> list[torch.Tensor]: return sum((head.get_output_weights() for head in self.heads), []) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: - self._layers_with_namespace[0].preprocess(kwargs) + if self._enabled: + self._layers_with_namespace[0].preprocess(kwargs) def get_loss_definitions(self, count: int = 1) -> list[LossDef]: - return self.blocks[0].get_loss_definitions(count=count * self._config.prediction_heads) + [ - loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count) - ] + return ( + self.blocks[0].get_loss_definitions(count=count * (self._config.prediction_heads - 1)) + + [loss_definition for head in self.heads for loss_definition in head.get_loss_definitions(count=count)] + if self._enabled + else [] + ) diff --git a/fast_llm/models/gpt/config.py b/fast_llm/models/gpt/config.py index 314741c3b..ddcbcf696 100644 --- a/fast_llm/models/gpt/config.py +++ b/fast_llm/models/gpt/config.py @@ -168,8 +168,8 @@ def _validate(self) -> None: for reference_model in self.reference_models.values(): Assert.geq( - reference_model.model.base_model.head.max_prediction_distance, - self.model.base_model.head.max_prediction_distance, + reference_model.model.base_model.head.prediction_heads, + self.model.base_model.head.prediction_heads, ) Assert.empty(reference_model.model.base_model.get_reference_models()) Assert.eq( diff --git a/fast_llm/models/gpt/conversion/llama.py b/fast_llm/models/gpt/conversion/llama.py index 00d871dbf..983df9869 100644 --- a/fast_llm/models/gpt/conversion/llama.py +++ b/fast_llm/models/gpt/conversion/llama.py @@ 
-488,16 +488,15 @@ def get_converters( cls, config: LanguageModelHeadConfig, exported_config: dict, - fast_llm_prefix: str, ) -> list[WeightConverter]: return [ *cls.normalization_converter_class.get_converters( config.normalization, - f"{fast_llm_prefix}.final_norm", + f"head.final_norm", f"model.norm", ), get_parameter_converter( - f"{fast_llm_prefix}.output_weights", + f"head.output_weights", "lm_head.weight", drop_on_import=exported_config["tie_word_embeddings"], drop_on_export=exported_config["tie_word_embeddings"], @@ -539,7 +538,7 @@ def get_converters(cls, config: GPTBaseModelConfig, exported_config: dict) -> li return [ *cls.embeddings_converter_class.get_converters(config.embeddings, "embeddings", "model"), *cls.decoder_converter_class.get_converters(config.decoder, "decoder", "model.layers"), - *cls.head_converter_class.get_converters(config.head, exported_config, "head"), + *cls.head_converter_class.get_converters(config.head, exported_config), ] diff --git a/fast_llm/models/gpt/conversion/mtp_llama.py b/fast_llm/models/gpt/conversion/mtp_llama.py index 5b83fed69..0c58b7be5 100644 --- a/fast_llm/models/gpt/conversion/mtp_llama.py +++ b/fast_llm/models/gpt/conversion/mtp_llama.py @@ -5,16 +5,14 @@ from fast_llm.engine.checkpoint.config import CheckpointFormat from fast_llm.engine.checkpoint.external import WeightConverter from fast_llm.layers.block.config import FixedBlockSequenceConfig -from fast_llm.layers.language_model.config import LanguageModelHeadConfig, MultiTokenPredictionConfig +from fast_llm.layers.language_model.config import LanguageModelHeadConfig from fast_llm.models.gpt.config import GPTModelConfig from fast_llm.models.gpt.conversion.config import MTPLlamaCheckpointFormat from fast_llm.models.gpt.conversion.llama import ( LlamaBaseModelConverter, - LlamaBlockConverter, LlamaDecoderConverter, LlamaHeadConverter, LlamaHuggingfaceCheckpointHandler, - get_parameter_converter, ) from fast_llm.utils import Assert, safe_merge_dicts @@ -23,17 
+21,14 @@ class MTPLlamaHeadConverter(LlamaHeadConverter): @classmethod def import_config(cls, config: dict) -> dict: return { - "type": "multi_token_prediction", - "block": LlamaBlockConverter.import_config(config), - "head": super().import_config(config), + **super().import_config(config), "prediction_heads": config["prediction_heads"], } @classmethod - def export_config(cls, config: MultiTokenPredictionConfig) -> dict: - Assert.custom(isinstance, config, MultiTokenPredictionConfig) + def export_config(cls, config: LanguageModelHeadConfig) -> dict: return safe_merge_dicts( - super().export_config(config.head), + super().export_config(config), {"prediction_heads": config.prediction_heads}, ) @@ -42,33 +37,15 @@ def get_converters( cls, config: LanguageModelHeadConfig, exported_config: dict, - fast_llm_prefix: str, ) -> list[WeightConverter]: - converters = [] - for prediction_distance in range(config.prediction_heads): - converters += cls.block_converter_class.get_converters( - config.block, - f"{fast_llm_prefix}.blocks.{prediction_distance}", - ( - f"model.layers.{exported_config["num_hidden_layers"]-1}" - if prediction_distance == 0 - else f"model.mtp_heads.{prediction_distance - 1}" - ), - ) - converters += cls.normalization_converter_class.get_converters( + return super().get_converters(config, exported_config) + [ + cls.normalization_converter_class.get_converters( config.head.normalization, - f"{fast_llm_prefix}.heads.{prediction_distance}.final_norm", + f"multi_token_prediction.heads.{prediction_distance - 1}.final_norm", f"model.mtp_norms.{prediction_distance}", ) - converters.append( - get_parameter_converter( - f"{fast_llm_prefix}.heads.0.output_weights", - "lm_head.weight", - drop_on_import=exported_config["tie_word_embeddings"], - ) - ) - - return converters + for prediction_distance in range(1, config.prediction_heads) + ] class MTPLlamaDecoderConverter(LlamaDecoderConverter): diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py 
index cabcdc489..698f624ed 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -250,7 +250,7 @@ def preprocess_batch( kwargs[BlockKwargs.output_hidden_states].append(re.compile(r"head\..*logits.*$")) else: labels_begin = tokens_begin + 1 - labels_end = tokens_end + self._config.head.max_prediction_distance + labels_end = tokens_end + self._config.head.prediction_heads labels = batch.tokens.crop(labels_begin, labels_end).tokens if batch.loss_masking_spans is not None: diff --git a/fast_llm/models/gpt/trainer.py b/fast_llm/models/gpt/trainer.py index ded0f81c8..ef4956176 100644 --- a/fast_llm/models/gpt/trainer.py +++ b/fast_llm/models/gpt/trainer.py @@ -25,7 +25,7 @@ def _get_sampling_parameters( { "sequence_length": self._config.batch.sequence_length, "truncate_documents": self._config.batch.truncate_documents, - "extra_tokens": self._config.model.base_model.head.max_prediction_distance, + "extra_tokens": self._config.model.base_model.head.prediction_heads, } ) return parameters if _return_dict else SamplingParameters(**parameters) diff --git a/fast_llm/models/multimodal/conversion/llava.py b/fast_llm/models/multimodal/conversion/llava.py index 8703ef920..a75d732b8 100644 --- a/fast_llm/models/multimodal/conversion/llava.py +++ b/fast_llm/models/multimodal/conversion/llava.py @@ -258,16 +258,15 @@ def get_converters( cls, config: LanguageModelHeadConfig, exported_config: dict, - fast_llm_prefix: str, ) -> list[WeightConverter]: return [ *cls.normalization_converter_class.get_converters( config.normalization, - f"{fast_llm_prefix}.final_norm", + f"head.final_norm", f"language_model.model.norm", ), get_parameter_converter( - f"{fast_llm_prefix}.output_weights", + f"head.output_weights", "language_model.lm_head.weight", drop_on_import=exported_config["tie_word_embeddings"], ), @@ -320,7 +319,7 @@ def get_converters(cls, config: MultiModalBaseModelConfig, exported_config: dict config.decoder, "decoder", "language_model.model.layers" ), 
*cls.language_model_converter_class.head_converter_class.get_converters( - config.head, {"tie_word_embeddings": False}, "head" + config.head, {"tie_word_embeddings": False} ), ] diff --git a/tests/layers/test_lm_head.py b/tests/layers/test_lm_head.py index b1a922099..a8ae85c12 100644 --- a/tests/layers/test_lm_head.py +++ b/tests/layers/test_lm_head.py @@ -47,6 +47,7 @@ def get_config(self) -> GPTModelConfig: "normalization": {"type": "rms_norm"}, "logits_scale_factor": self.logits_scale_factor, "cross_entropy_splits": self.num_splits, + "prediction_heads": self.prediction_heads, } losses = {} if self.label_loss is not False: @@ -69,15 +70,7 @@ def get_config(self) -> GPTModelConfig: "base_model": { "decoder": {"num_blocks": 0}, "embeddings": {"vocab_size": VOCAB_SIZE, "full_precision_residual": self.full_precision_residual}, - "head": ( - head_config - if self.prediction_heads == 1 - else { - "type": "multi_token_prediction", - "head": head_config, - "prediction_heads": self.prediction_heads, - } - ), + "head": head_config, "hidden_size": HIDDEN_SIZE, "tied_embedding_weight": self.tied_embedding_weight, }, @@ -246,8 +239,9 @@ def test_lm_head(test_config: LMHeadTestConfig): else None ) - for prediction_distance, head in enumerate(model.head.heads): + for prediction_distance in range(model_config.base_model.head.prediction_heads): # Prepare the LM head + head = model.head if prediction_distance == 0 else model.multi_token_prediction.heads[prediction_distance - 1] Assert.custom(isinstance, head, LanguageModelHead) Assert.eq(head._prediction_distance, prediction_distance) is_duplicate = test_config.tied_embedding_weight or prediction_distance > 0 diff --git a/tests/utils/model_configs.py b/tests/utils/model_configs.py index 7b41c1f50..40dbb7d29 100644 --- a/tests/utils/model_configs.py +++ b/tests/utils/model_configs.py @@ -470,13 +470,8 @@ def update_and_add_testing_config( "llama", "mtp_llama", updates={ - ("model", "base_model", "head"): { - "type": 
"multi_token_prediction", - "block": _llama_block, - "head": MODEL_CONFIGS["llama"].config_dict["model"]["base_model"]["head"], - "prediction_heads": 2, - }, ("model", "base_model", "decoder", "num_blocks"): 1, + ("model", "base_model", "head", "prediction_heads"): 1, }, # Megatron doesn't support multi-token prediction. megatron_args=None, From f803e822cc65592787e519c6227b373244f36a81 Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Wed, 11 Feb 2026 16:12:34 -0500 Subject: [PATCH 21/22] misc --- fast_llm/layers/language_model/loss/entropy_loss.py | 8 -------- fast_llm/layers/language_model/loss/loss.py | 1 - fast_llm/layers/language_model/loss/z_loss.py | 6 ------ 3 files changed, 15 deletions(-) diff --git a/fast_llm/layers/language_model/loss/entropy_loss.py b/fast_llm/layers/language_model/loss/entropy_loss.py index 537c7996d..e326b9555 100644 --- a/fast_llm/layers/language_model/loss/entropy_loss.py +++ b/fast_llm/layers/language_model/loss/entropy_loss.py @@ -13,9 +13,6 @@ class LanguageModelLabelEntropyLoss[ConfigType: LanguageModelLabelEntropyLossConfig](LanguageModelLoss[ConfigType]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - def forward_backward( self, logits: "torch.Tensor", @@ -41,11 +38,6 @@ def forward_backward( class LanguageModelDistillationLoss[ConfigType: LanguageModelDistillationLossConfig](LanguageModelLoss[ConfigType]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - if self._prediction_distance > 0: - raise NotImplementedError() - def forward_backward( self, logits: "torch.Tensor", diff --git a/fast_llm/layers/language_model/loss/loss.py b/fast_llm/layers/language_model/loss/loss.py index 8dc88c4a1..f1f65ac39 100644 --- a/fast_llm/layers/language_model/loss/loss.py +++ b/fast_llm/layers/language_model/loss/loss.py @@ -102,7 +102,6 @@ def _get_loss_mask(self, kwargs: dict[str, typing.Any], split_index: int = 0): return None if loss_mask is None else 
self._prepare_target(loss_mask, kwargs, split_index) def _get_reference_model_logits(self, reference_model: str, kwargs: dict[str, typing.Any], split_index: int = 0): - assert self._prediction_distance == 0 Assert.incl( logits_name := self.module_name.rsplit(".", 2)[0] + f".logits", reference_hidden_states := kwargs[f"reference_{reference_model}_hidden_states"], diff --git a/fast_llm/layers/language_model/loss/z_loss.py b/fast_llm/layers/language_model/loss/z_loss.py index c606e2d68..720592c41 100644 --- a/fast_llm/layers/language_model/loss/z_loss.py +++ b/fast_llm/layers/language_model/loss/z_loss.py @@ -10,12 +10,6 @@ class LanguageModelZLoss[ConfigType: LanguageModelZLossConfig](LanguageModelLoss[ConfigType]): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # TODO: Support vocab_parallel - if self._vocab_parallel: - raise NotImplementedError() - def forward_backward( self, logits: "torch.Tensor", From da9751bf81a3ce9a18aff0b2b39f94a7ddff293e Mon Sep 17 00:00:00 2001 From: Joel Lamy-Poirier Date: Sat, 7 Mar 2026 02:51:54 -0500 Subject: [PATCH 22/22] fixes --- fast_llm/layers/attention/attention.py | 4 +- fast_llm/layers/attention/config.py | 2 +- fast_llm/layers/attention/preprocessing.py | 6 +- fast_llm/layers/block/config.py | 2 +- fast_llm/layers/common/linear/convolution.py | 20 +-- fast_llm/layers/language_model/embedding.py | 2 +- fast_llm/layers/ssm/gdn.py | 121 +++++++------------ fast_llm/layers/ssm/kda.py | 107 +++++----------- fast_llm/layers/ssm/mamba.py | 6 +- fast_llm/models/gpt/model.py | 4 +- fast_llm/models/multimodal/model.py | 2 +- tests/layers/test_attention.py | 2 +- tests/layers/test_ssm.py | 65 +++++----- tests/layers/test_varlen.py | 4 +- 14 files changed, 134 insertions(+), 213 deletions(-) diff --git a/fast_llm/layers/attention/attention.py b/fast_llm/layers/attention/attention.py index 859bafea2..40b1f4d23 100644 --- a/fast_llm/layers/attention/attention.py +++ b/fast_llm/layers/attention/attention.py @@ 
-324,7 +324,7 @@ def _forward( token_dims = (kwargs[AttentionKwargs.batch_dim], kwargs[AttentionKwargs.sequence_q_dim]) token_shape = tuple(dim.size for dim in token_dims) query = query.unflatten(0, token_shape) - key_value = key_value.unflatten(0, token_shape) + key_value = key_value.unflatten(0, (token_shape[0], token_shape[1] * self._sequence_data_parallel_dim.size)) # TODO: Move the rest to function. @@ -457,7 +457,7 @@ def _preprocess_for_backup_attention(self, kwargs: dict[str, typing.Any]) -> Non seq_ids = torch.stack( [ torch.cat([torch.full((x,), i, device=device) for i, x in enumerate(sample_lens)]) - for sample_lens in kwargs[AttentionKwargs.sequence_lengths] + for sample_lens in kwargs[AttentionKwargs.lengths] ] ) document_mask = (seq_ids[:, None, :] == seq_ids[:, :, None])[ diff --git a/fast_llm/layers/attention/config.py b/fast_llm/layers/attention/config.py index 40baf2009..cad1d20e8 100644 --- a/fast_llm/layers/attention/config.py +++ b/fast_llm/layers/attention/config.py @@ -20,7 +20,7 @@ class MixerKwargs(BlockKwargs): cu_seqlens_k = "cu_seqlens_k" max_seqlen_q = "max_seqlen_q" max_seqlen_k = "max_seqlen_k" - seq_idx = "seq_idx" + document_index_q = "document_index_q" position_ids = "position_ids" diff --git a/fast_llm/layers/attention/preprocessing.py b/fast_llm/layers/attention/preprocessing.py index a9d9936c5..fd048cf76 100644 --- a/fast_llm/layers/attention/preprocessing.py +++ b/fast_llm/layers/attention/preprocessing.py @@ -28,9 +28,7 @@ def preprocess_for_varlen( Assert.eq(kwargs[MixerKwargs.sequence_k_dim].global_size, kwargs[MixerKwargs.sequence_q_dim].global_size) sequence_lengths = [ - sequence_length - for sequence_lengths in kwargs[MixerKwargs.sequence_lengths] - for sequence_length in sequence_lengths + sequence_length for sequence_lengths in kwargs[MixerKwargs.lengths] for sequence_length in sequence_lengths ] if return_cu_seqlens: cu_seqlens_q = torch.tensor([0] + sequence_lengths, dtype=torch.int32, device=device).cumsum( @@ -43,7 
+41,7 @@ def preprocess_for_varlen( kwargs[MixerKwargs.max_seqlen_q] = max_seqlen_q kwargs[MixerKwargs.max_seqlen_k] = max_seqlen_q if return_seq_idx: - kwargs[MixerKwargs.seq_idx] = torch.cat( + kwargs[MixerKwargs.document_index_q] = torch.cat( [ torch.full((sequence_length,), i, dtype=torch.int32, device=device) for i, sequence_length in enumerate(sequence_lengths) diff --git a/fast_llm/layers/block/config.py b/fast_llm/layers/block/config.py index 4f8595250..729cdd8a2 100644 --- a/fast_llm/layers/block/config.py +++ b/fast_llm/layers/block/config.py @@ -38,7 +38,7 @@ class BlockKwargs: hidden_token_dim = "hidden_token_dim" # TODO: These are confusing sequence_length = "sequence_length" - sequence_lengths = "sequence_lengths" + lengths = "sequence_lengths" # TODO: Belongs elsewhere? grad_output = "grad_output" activation_distillation_targets = "activation_distillation_targets" diff --git a/fast_llm/layers/common/linear/convolution.py b/fast_llm/layers/common/linear/convolution.py index e8b00fb3c..046d55194 100644 --- a/fast_llm/layers/common/linear/convolution.py +++ b/fast_llm/layers/common/linear/convolution.py @@ -34,11 +34,13 @@ def __init__( else self._forward_torch ) - def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: - if kwargs: - raise NotImplementedError( - f"Arguments {tuple(kwargs)} not implemented for torch implementation of 1d convolution." 
- ) + def _forward_torch( + self, input_: torch.Tensor, document_index: torch.Tensor | None = None, lengths: list[int] | None = None + ) -> torch.Tensor: + if document_index is not None and lengths is None: + raise ValueError("Torch implementation of CausalConv1d requires lengths.") + if lengths is not None: + return torch.cat([self._forward_torch(x) for x in input_.split(lengths, dim=-1)], dim=-1) return self._activation.activation_fn( torch.nn.functional.conv1d( input_, @@ -49,13 +51,17 @@ def _forward_torch(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: )[..., : input_.size(1)] ) - def _forward_causal_conv1d(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: + def _forward_causal_conv1d( + self, input_: torch.Tensor, document_index: torch.Tensor | None = None, lengths: list[int] | None = None + ) -> torch.Tensor: + if lengths is not None and document_index is None: + raise ValueError("Compiled implementation of CausalConv1d requires document indices.") return _causal_conv1d_fn( input_, self.weight.squeeze(1), self.bias, activation=(None if self._activation == ActivationType.identity else self._activation.value), - **kwargs, + seq_idx=document_index, ) def get_compute_usage(self, input_: TensorMeta, config: ResourceUsageConfig) -> int: diff --git a/fast_llm/layers/language_model/embedding.py b/fast_llm/layers/language_model/embedding.py index c6df8f62b..f1f1dea75 100644 --- a/fast_llm/layers/language_model/embedding.py +++ b/fast_llm/layers/language_model/embedding.py @@ -87,7 +87,7 @@ def _forward( if self._vocab_parallel: token_mask = (token_ids >= self._vocab_start_index) * (token_ids < self._vocab_end_index) masked_input = (token_ids - self._vocab_start_index) * token_mask - embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(2) # noqa + embeddings = torch.embedding(self.word_embeddings_weight, masked_input) * token_mask.unsqueeze(-1) # noqa embeddings = reduce_forward(embeddings, group) # TODO: Input 
masking of position embeddings inconsistant with non-vocab-parallel if self.position_embeddings_weight is not None: diff --git a/fast_llm/layers/ssm/gdn.py b/fast_llm/layers/ssm/gdn.py index 5e721d424..70c2fda26 100644 --- a/fast_llm/layers/ssm/gdn.py +++ b/fast_llm/layers/ssm/gdn.py @@ -2,8 +2,6 @@ import typing import torch -import torch.nn.functional as F -from einops import rearrange from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -50,6 +48,7 @@ def torch_chunk_gated_delta_rule( initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=False, + cu_seqlens=None, ): initial_dtype = query.dtype if use_qk_l2norm_in_kernel: @@ -62,11 +61,11 @@ def torch_chunk_gated_delta_rule( batch_size, num_heads, sequence_length, k_head_dim = key.shape v_head_dim = value.shape[-1] pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size - query = F.pad(query, (0, 0, 0, pad_size)) - key = F.pad(key, (0, 0, 0, pad_size)) - value = F.pad(value, (0, 0, 0, pad_size)) - beta = F.pad(beta, (0, pad_size)) - g = F.pad(g, (0, pad_size)) + query = torch.nn.functional.pad(query, (0, 0, 0, pad_size)) + key = torch.nn.functional.pad(key, (0, 0, 0, pad_size)) + value = torch.nn.functional.pad(value, (0, 0, 0, pad_size)) + beta = torch.nn.functional.pad(beta, (0, pad_size)) + g = torch.nn.functional.pad(g, (0, pad_size)) total_sequence_length = sequence_length + pad_size scale = 1 / (query.shape[-1] ** 0.5) query = query * scale @@ -252,36 +251,6 @@ def __init__( ) self.chunk_gated_delta_rule = torch_chunk_gated_delta_rule - if not _causal_conv1d_available: - raise RuntimeError("Gated delta net requires `causal_conv1d`.") - - def fix_query_key_value_ordering(self, mixed_qkvz, mixed_ba): - """ - Derives `query`, `key` and `value` tensors from `mixed_qkvz` and `mixed_ba`. - Replaces fix_query_key_value_ordering from Qwen due to layout differences. 
- """ - - local_qkv_sizes = ( - self._local_key_heads * self._config.key_head_dim, - self._local_key_heads * self._config.key_head_dim, - self._local_value_heads * self._config.value_head_dim, - self._local_value_heads * self._config.value_head_dim, - ) - query, key, value, z = torch.split(mixed_qkvz, local_qkv_sizes, dim=-1) - query = query.reshape(*query.shape[:-1], self._local_key_heads, self._config.key_head_dim) - key = key.reshape(*key.shape[:-1], self._local_key_heads, self._config.key_head_dim) - value = value.reshape(*value.shape[:-1], self._local_value_heads, self._config.value_head_dim) - z = z.reshape(*z.shape[:-1], self._local_value_heads, self._config.value_head_dim) - - beta, alpha = torch.split( - mixed_ba, - (self._local_value_heads, self._local_value_heads), - dim=-1, - ) - beta = beta.reshape(*beta.shape[:-1], self._local_value_heads) - alpha = alpha.reshape(*alpha.shape[:-1], self._local_value_heads) - return query, key, value, z, beta, alpha - def _forward( self, input_: torch.Tensor, @@ -301,74 +270,72 @@ def _forward( """ # in sequence parallel TP the input here is already scattered across sequence dimension - # TODO: fuse soome of the reshapes into rearranges + # TODO: fuse some of the reshapes into rearranges hidden_states = input_ + # TODO: ====== Merge qkvz and ba ====== projected_states_qkvz = self.in_proj_qkvz(hidden_states) # bs/seq x seq_len/bs x (qkvz) projected_states_ba = self.in_proj_ba(hidden_states) # bs/seq x seq_len/bs x (b a) - batch_size, sequence_length = projected_states_qkvz.shape[:2] + query_key_value, z = torch.split( + projected_states_qkvz, + [ + 2 * self._local_key_heads * self._config.key_head_dim + + self._local_value_heads * self._config.value_head_dim, + self._local_value_heads * self._config.value_head_dim, + ], + dim=-1, + ) - # note: to support var len training (packing) we need to flatten hidden states to batch_size = 1 - # this is does not seem to be required by causal_conv1d_fn, but it it required by 
chunked_gdn_rule: https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/gated_delta_rule/chunk.py#L299 - # similarly to kimi linear and to SHortCOnv from fla, we pass it flattened tro conv_1d as well, i.e. see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914 - query, key, value, z, beta, alpha = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba + # Move sequence dim to last so the convolution acts on it, add pretend batch dimension. + # sequence, qkv_total -> 1, qkv_total, sequence + query_key_value = query_key_value.unsqueeze(0).transpose(1, 2) + query_key_value = self.convolution( + query_key_value, + document_index=kwargs[MixerKwargs.document_index_q].unsqueeze(0), + lengths=[length for lengths in kwargs[MixerKwargs.lengths] for length in lengths], ) - query, key, value = (x.reshape(x.shape[0], x.shape[1], -1) for x in (query, key, value)) - - mixed_qkv = torch.cat((query, key, value), dim=-1) - mixed_qkv = rearrange(mixed_qkv, "b s ... -> (b s) ...").unsqueeze(0) # 1 s d - mixed_qkv = rearrange(mixed_qkv, "b t d -> b d t") # mixed_qkv.transpose(1, 2) - # conv func. 
gets sequence dim as last dim, see https://github.com/Dao-AILab/causal-conv1d/blob/22a4577d8ace9d5703daea91a7fb56695492152b/causal_conv1d/causal_conv1d_interface.py#L110 - mixed_qkv = self.convolution(mixed_qkv, seq_idx=kwargs[MixerKwargs.seq_idx].unsqueeze(0)) - mixed_qkv = rearrange(mixed_qkv, "b d t -> b t d") # mixed_qkv.transpose(1, 2) + # 1, qkv_total, sequence -> 1, sequence, qkv_total + query_key_value = query_key_value.transpose(1, 2) query, key, value = torch.split( - mixed_qkv, - ( + query_key_value, + [ self._local_key_heads * self._config.key_head_dim, self._local_key_heads * self._config.key_head_dim, self._local_value_heads * self._config.value_head_dim, - ), + ], dim=-1, ) - query = query.reshape(query.shape[0], query.shape[1], -1, self._config.key_head_dim) - key = key.reshape(key.shape[0], key.shape[1], -1, self._config.key_head_dim) - value = value.reshape(value.shape[0], value.shape[1], -1, self._config.value_head_dim) - beta = beta.sigmoid() - g = -self.A_log.float().exp() * F.softplus(alpha.float() + self.dt_bias) - - beta = rearrange(beta, "b s ... -> (b s) ...").unsqueeze(0) - g = rearrange(g, "b s ... 
-> (b s) ...").unsqueeze(0) + # 1, sequence, heads, head_dim + query = query.unflatten(-1, (self._local_key_heads, self._config.key_head_dim)) + key = key.unflatten(-1, (self._local_key_heads, self._config.key_head_dim)) + value = value.unflatten(-1, (self._local_value_heads, self._config.value_head_dim)) if self._value_heads_per_key > 1: query = query.repeat_interleave(self._value_heads_per_key, dim=2) key = key.repeat_interleave(self._value_heads_per_key, dim=2) - core_attn_out, _ = self.chunk_gated_delta_rule( + beta, alpha = torch.split(projected_states_ba, [self._local_value_heads, self._local_value_heads], dim=-1) + + out, _ = self.chunk_gated_delta_rule( query, key, value, - g=g, - beta=beta, + g=self._calculate_g(alpha).unsqueeze(0), + beta=beta.sigmoid().unsqueeze(0), initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q], ) + out = out.squeeze(0) + out = self.norm(out, z.reshape_as(out)) + return self.out_proj(out.flatten(-2)) - z_shape_og = z.shape - core_attn_out = rearrange(core_attn_out.squeeze(0), "(b s) ... 
-> b s ...", b=batch_size, s=sequence_length) - - core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) - z = z.reshape(-1, z.shape[-1]) - core_attn_out = self.norm(core_attn_out, z) - core_attn_out = core_attn_out.reshape(z_shape_og) - core_attn_out = core_attn_out.reshape(core_attn_out.shape[0], core_attn_out.shape[1], -1) - output = self.out_proj(core_attn_out) - - return output + @torch.compile + def _calculate_g(self, alpha: torch.Tensor) -> torch.Tensor: + return -self.A_log.float().exp() * torch.nn.functional.softplus(alpha.float() + self.dt_bias) def preprocess(self, kwargs: dict[str, typing.Any]) -> None: preprocess_for_varlen( diff --git a/fast_llm/layers/ssm/kda.py b/fast_llm/layers/ssm/kda.py index 07ca3a997..608fb5921 100644 --- a/fast_llm/layers/ssm/kda.py +++ b/fast_llm/layers/ssm/kda.py @@ -2,7 +2,6 @@ import typing import torch -from einops import rearrange, repeat from fast_llm.engine.base_model.config import ResourceUsageConfig from fast_llm.engine.config_utils.initialization import LambdaInitializer, init_normal_, init_ones_ @@ -27,16 +26,6 @@ _kda_available = False -def index_first_axis(x, indices): - other_shape = x.shape[1:] - second_dim = other_shape.numel() - return torch.gather( - rearrange(x, "b ... -> b (...)"), - 0, - repeat(indices, "z -> z d", d=second_dim), - ).reshape(-1, *other_shape) - - class KimiDeltaAttention[ConfigType: KimiDeltaAttentionConfig](BlockWithBias[ConfigType]): """ Implementation of the Kimi Delta Attention mixer. @@ -200,24 +189,6 @@ def __init__( peft=self._peft, ) - def _apply_conv(self, tensor: torch.Tensor, conv: torch.nn.Module, seq_idx: torch.Tensor = None) -> torch.Tensor: - """ - Applies convolution. - Note, in the reference code they use Short Convolution from flash-linear-attention/fla/modules/convolution.py, but that one just uses causal_conv1d anyways. - Varlen: - - seq. idx are only suppored in channel last layout, i.e. 
no transpose - """ - tensor = rearrange(tensor, "b t d -> b d t") - # tensor = tensor.transpose(1, 2).contiguous() if seq_idx is None else tensor.transpose(1, 2) - tensor = conv(tensor, seq_idx=seq_idx) - return tensor.transpose(1, 2).contiguous() - - def _reshape_heads(self, tensor: torch.Tensor) -> torch.Tensor: - tensor = tensor.contiguous() - # since head_dim is the same vor k,q and v - # same as rearrange(v, '... (h d) -> ... h d', d=self.head_dim) - return tensor.view(tensor.shape[0], tensor.shape[1], self._local_heads, self._config.head_dim) - def _forward( self, input_: torch.Tensor, @@ -226,66 +197,54 @@ def _forward( metrics: dict[str, typing.Any] | None = None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ - Same as in gdn, the idea is to always do forward pass in a packed way, whcih is required for varlen support. + Same as in gdn, the idea is to always do forward pass in a packed way, which is required for varlen support. """ - hidden_states = input_ - - # TODO: can be made more efficeint by rearranging hidden states directly and only once - residual_dtype = hidden_states.dtype - - q = self.q_proj(hidden_states) - k = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - batch_size, sequence_length, _ = q.size() - q = rearrange(q, "b s ... -> (b s) ...").unsqueeze(0) - k = rearrange(k, "b s ... -> (b s) ...").unsqueeze(0) - v = rearrange(v, "b s ... 
-> (b s) ...").unsqueeze(0) - # because we use cu_seqlens, chunk_kda requires batch size to be 1 (flatten, see https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/ops/kda/chunk.py#L303) - # similarly to ShortConvolution from fla we already operate on flattened batches here (https://github.com/fla-org/flash-linear-attention/blob/71260ecd573cfaaa94305b726465143199e99734/fla/modules/convolution.py#L914) - seq_idx = kwargs[MixerKwargs.seq_idx].unsqueeze(0) - q = self._apply_conv(q, self.q_conv, seq_idx) - k = self._apply_conv(k, self.k_conv, seq_idx) - v = self._apply_conv(v, self.v_conv, seq_idx) - - g_kernel = self.f_b_proj(self.f_a_proj(hidden_states)) - g_kernel = self._reshape_heads(g_kernel) - g_kernel = rearrange(g_kernel, "b s ... -> (b s) ...").unsqueeze(0) + # TODO: ===== Merge q,k,v into a single tensor ====== + q = self.q_proj(input_) + k = self.k_proj(input_) + v = self.v_proj(input_) + + document_index = kwargs[MixerKwargs.document_index_q].unsqueeze(0) + lengths = [length for lengths in kwargs[MixerKwargs.lengths] for length in lengths] + # Move sequence dim to last so the convolution acts on it, add pretend batch dimension. 
+ q = ( + self.q_conv(q.unsqueeze(0).transpose(1, 2), document_index=document_index, lengths=lengths) + .transpose(1, 2) + .unflatten(-1, (self._local_heads, self._config.head_dim)) + ) + k = ( + self.k_conv(k.unsqueeze(0).transpose(1, 2), document_index=document_index, lengths=lengths) + .transpose(1, 2) + .unflatten(-1, (self._local_heads, self._config.head_dim)) + ) + v = ( + self.v_conv(v.unsqueeze(0).transpose(1, 2), document_index=document_index, lengths=lengths) + .transpose(1, 2) + .unflatten(-1, (self._local_heads, self._config.head_dim)) + ) + g_kernel = ( + self.f_b_proj(self.f_a_proj(input_)).unsqueeze(0).unflatten(-1, (self._local_heads, self._config.head_dim)) + ) g_kernel = fused_kda_gate(g_kernel, self.A_log.float(), dt_bias=self.dt_bias) - beta = torch.sigmoid(self.beta_proj(hidden_states).float()) - q = self._reshape_heads(q) - k = self._reshape_heads(k) - v = self._reshape_heads(v) - beta = rearrange(beta, "b s h -> (b s) h").unsqueeze(0) - - # need to install nightly triton for this to work on H100, see https://github.com/fla-org/flash-linear-attention/blob/main/FAQs.md - # cu_seqlens requires batch ssize to be 1, i.e. 
flattened bacthes - attn_out, _ = chunk_kda( + out, _ = chunk_kda( q=q, k=k, v=v, g=g_kernel, - beta=beta, + beta=torch.sigmoid(self.beta_proj(input_).float()).unsqueeze(0), initial_state=None, output_final_state=False, use_qk_l2norm_in_kernel=True, cu_seqlens=kwargs[MixerKwargs.cu_seqlens_q], ) + out = out.to(input_.dtype).squeeze(0) - attn_out = attn_out.to(residual_dtype) - - g_out = self.g_b_proj(self.g_a_proj(hidden_states)) # bs x seq x n_local_heads x head dim - g_out = self._reshape_heads(g_out) - - attn_out = rearrange(attn_out.squeeze(0), "(b s) h d -> b s h d", b=batch_size, s=sequence_length) - attn_out = self.norm(attn_out, g_out) - attn_out = rearrange(attn_out, "b s h d -> b s (h d)") - attn_out = self.o_proj(attn_out) - - return attn_out + g_out = self.g_b_proj(self.g_a_proj(input_)) + out = self.norm(out, g_out.view_as(out)) + return self.o_proj(out.flatten(-2)) def get_compute_usage(self, input_: TensorMeta, kwargs: dict[str, typing.Any], config: ResourceUsageConfig) -> int: raise NotImplementedError() diff --git a/fast_llm/layers/ssm/mamba.py b/fast_llm/layers/ssm/mamba.py index fd6255e6c..8a7ae2805 100644 --- a/fast_llm/layers/ssm/mamba.py +++ b/fast_llm/layers/ssm/mamba.py @@ -167,7 +167,7 @@ def _forward( assert _mamba_available sequence_length = kwargs[BlockKwargs.sequence_q_dim].size - token_shape = (kwargs[BlockKwargs.batch_dim].size, kwargs[BlockKwargs.sequence_q_dim].size) + token_shape = (div(input_.size(0), sequence_length), sequence_length) # inner_projection : (local_tokens, hidden) -> (batch, sequence, local_inner_projection) inner_projection = self.in_proj(input_).unflatten(0, token_shape) dt = self.dt_proj(self.dt_in_proj(input_)).unflatten(0, token_shape) @@ -184,7 +184,9 @@ def _forward( # x: (batch, sequence, local_head_groups * state) -> (batch, local_heads * state, sequence) x = x.transpose(1, 2) convolution_kwargs = ( - {} if self._config.cross_document_attention else {"seq_idx": kwargs[MixerKwargs.seq_idx].unsqueeze(0)} + {} 
+ if self._config.cross_document_attention + else {"seq_idx": kwargs[MixerKwargs.document_index_q].unsqueeze(0)} ) if self._config.repeat_kv_before_conv: x = self.convolution( diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py index 698f624ed..7a6f7ffac 100644 --- a/fast_llm/models/gpt/model.py +++ b/fast_llm/models/gpt/model.py @@ -80,7 +80,7 @@ def preprocess_meta( ) # The token dimension as appears in hidden states, i.e. with possible sequence-tensor-parallel split. hidden_token_dim = ( - ( + TensorDim( "token_tp", token_dim.global_size, self._distributed_config.get_distributed_dim(DistributedDimNames.tensor_and_data), @@ -194,7 +194,7 @@ def preprocess_batch( AttentionKwargs.past_key_values: pasts, AttentionKwargs.presents: presents, BlockKwargs.iteration: iteration, - AttentionKwargs.sequence_lengths: cropped_tokens.lengths, + AttentionKwargs.lengths: cropped_tokens.lengths, AttentionKwargs.device: self._distributed.device, BlockKwargs.output_hidden_states: [], BlockKwargs.hidden_states: {}, diff --git a/fast_llm/models/multimodal/model.py b/fast_llm/models/multimodal/model.py index e90bd4d89..dab9c8027 100644 --- a/fast_llm/models/multimodal/model.py +++ b/fast_llm/models/multimodal/model.py @@ -185,7 +185,7 @@ def preprocess_batch( kwargs[self._vision_encoder_namespace] = { **kwargs[self._vision_encoder_namespace], VisionKwargs.patch_positions: positions, - VisionKwargs.sequence_lengths: [cropped_image_patches.lengths + [pad_size]], + VisionKwargs.lengths: [cropped_image_patches.lengths + [pad_size]], VisionKwargs.sequence_length: sequence_length, VisionKwargs.device: self._distributed.device, VisionKwargs.output_hidden_states: kwargs.get(VisionKwargs.output_hidden_states, []), diff --git a/tests/layers/test_attention.py b/tests/layers/test_attention.py index 924c2cc7f..e71441015 100644 --- a/tests/layers/test_attention.py +++ b/tests/layers/test_attention.py @@ -35,7 +35,7 @@ def test_attention_implementations(cross_document_attention: 
bool, causal: bool, kwargs = { AttentionKwargs.device: device, AttentionKwargs.sequence_length: 100, - AttentionKwargs.sequence_lengths: [ + AttentionKwargs.lengths: [ [20, 32, 10, 11, 9, 18], [100], [2, 8, 22, 7, 6, 5, 1, 10, 4, 11, 3, 8, 4, 9], diff --git a/tests/layers/test_ssm.py b/tests/layers/test_ssm.py index d096b4af3..c281de0d3 100644 --- a/tests/layers/test_ssm.py +++ b/tests/layers/test_ssm.py @@ -18,13 +18,16 @@ Apriel2GatedDeltaNet, Apriel2Mamba, KimiDeltaAttention, + is_fast_path_available, ) except ImportError: Apriel2GatedDeltaNet = None Apriel2Mamba = None + is_fast_path_available = False HIDDEN_SIZE = 16 -SEQ_LEN = 65 +SEQUENCE_LENGTH = 65 +BATCH_SIZE = 2 def _compare_mixers( @@ -53,8 +56,8 @@ def _compare_mixers( Assert.rms_close_relative(fast_param, hf_param.view_as(fast_param), threshold, 1e-5, msg=name) hidden_states = torch.randn( - 2, - SEQ_LEN, + BATCH_SIZE, + SEQUENCE_LENGTH, HIDDEN_SIZE, device=distributed.device, dtype=distributed_config.compute_dtype.torch, @@ -66,37 +69,31 @@ def _compare_mixers( if isinstance(hf_out, tuple): (hf_out,) = hf_out - sequence_lengths = [[SEQ_LEN] for _ in range(hidden_states.size(0))] + sequence_lengths = [[SEQUENCE_LENGTH] for _ in range(hidden_states.size(0))] fast_kwargs = { BlockKwargs.device: distributed.device, - BlockKwargs.sequence_lengths: sequence_lengths, - BlockKwargs.sequence_q_dim: TensorDim("", SEQ_LEN), - BlockKwargs.sequence_k_dim: TensorDim("", SEQ_LEN), + BlockKwargs.lengths: sequence_lengths, + BlockKwargs.sequence_q_dim: TensorDim("", SEQUENCE_LENGTH), + BlockKwargs.sequence_k_dim: TensorDim("", SEQUENCE_LENGTH), } fast_llm_layer.train() fast_llm_layer.preprocess(fast_kwargs) - fast_out = fast_llm_layer(hidden_states, fast_kwargs) + fast_out = fast_llm_layer(hidden_states.flatten(0, 1), fast_kwargs).view_as(hidden_states) Assert.rms_close_relative(fast_out, hf_out, threshold, 1e-5) @pytest.mark.slow # Arguments ('seq_idx',) not implemented for torch implementation of 1d convolution. 
-@pytest.mark.skipif(not transformers.utils.import_utils.is_causal_conv1d_available(), reason="GDN deps missing") +@pytest.mark.skipif(not is_fast_path_available, reason="GDN deps missing") def test_gdn(testing_device): dtype = torch.bfloat16 - - NUM_V_HEADS = 4 - NUM_K_HEADS = 2 - HEAD_DIM = 4 - KERNEL_SIZE = 4 - config_common = { - "value_heads": NUM_V_HEADS, - "key_heads": NUM_K_HEADS, - "key_head_dim": HEAD_DIM, - "value_head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + "value_heads": 4, + "key_heads": 2, + "key_head_dim": 4, + "value_head_dim": 4, + "convolution_layer": {"kernel_size": 4, "activation": "silu"}, } hf_layer = ( @@ -111,14 +108,10 @@ def test_gdn(testing_device): @pytest.mark.slow @pytest.mark.skipif(not _kda_available, reason="KDA fused kernels not available") def test_kda(): - NUM_HEADS = 4 - HEAD_DIM = 4 - KERNEL_SIZE = 4 - kda_config = { - "heads": NUM_HEADS, - "head_dim": HEAD_DIM, - "convolution_layer": {"kernel_size": KERNEL_SIZE, "activation": "silu"}, + "heads": 4, + "head_dim": 4, + "convolution_layer": {"kernel_size": 4, "activation": "silu"}, "normalization": {"epsilon": 1e-5, "activation": "sigmoid"}, } @@ -130,21 +123,17 @@ def test_kda(): @pytest.mark.slow +@pytest.mark.skip("Mamba is broken") @pytest.mark.parametrize("add_linear_biases", [True, False]) @pytest.mark.parametrize("repeat_kv_before_conv", [True, False]) @pytest.mark.skipif(not transformers.utils.import_utils.is_mamba_ssm_available(), reason="Mamba not available") def test_mamba(add_linear_biases, repeat_kv_before_conv): - D_INNER = 128 - D_XB = 64 - D_STATE = 16 - D_CONV = 4 - DT_RANK = 4 config_common = { - "d_inner": D_INNER, - "d_xb": D_XB, - "state_size": D_STATE, - "dt_rank": DT_RANK, + "d_inner": 128, + "d_xb": 64, + "state_size": 16, + "dt_rank": 4, "repeat_kv_before_conv": repeat_kv_before_conv, "add_linear_biases": add_linear_biases, } @@ -152,13 +141,13 @@ def test_mamba(add_linear_biases, 
repeat_kv_before_conv):
     mamba_config = {
         "conv_bias": add_linear_biases,
         "dt_proj_bias": add_linear_biases,
-        **config_common,
+        "d_conv": 4,
+        **config_common,
     }
     hf_layer = Apriel2Mamba(HIDDEN_SIZE, mamba_config, layer_idx=0)
 
     # Create Fast-LLM Mamba layer
     fast_llm_config = MambaConfig(
-        convolution_layer={"kernel_size": D_CONV},
+        convolution_layer={"kernel_size": 4},
         **config_common,
     )
 
diff --git a/tests/layers/test_varlen.py b/tests/layers/test_varlen.py
index d31cffa50..350259375 100644
--- a/tests/layers/test_varlen.py
+++ b/tests/layers/test_varlen.py
@@ -71,7 +71,7 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig):
 
     kwargs_packed = {
         **kwargs,
-        BlockKwargs.sequence_lengths: sequence_lengths,
+        BlockKwargs.lengths: sequence_lengths,
         BlockKwargs.sequence_length: seq_len,
         BlockKwargs.batch_dim: TensorDim("", batch_size),
         BlockKwargs.sequence_q_dim: TensorDim("", seq_len),
@@ -93,7 +93,7 @@ def test_mixer_varlen_stacking_equivalence(config: MixerConfig):
     seq_len_ = len(seq)
     kwargs_seq = {
         **kwargs,
-        BlockKwargs.sequence_lengths: [[seq_len_]],
+        BlockKwargs.lengths: [[seq_len_]],
        BlockKwargs.sequence_length: seq_len_,
         BlockKwargs.batch_dim: TensorDim("", 1),
         BlockKwargs.sequence_q_dim: TensorDim("", seq_len_),