Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ def build(self, device: torch.device | None = None, dtype: torch.dtype | None =
model_paths = list(self.model_path) if isinstance(self.model_path, tuple) else [self.model_path]
model_state_dict = self.load_sd(model_paths, sd_ops=self.model_sd_ops, registry=self.registry, device=device)

# Extract per-tensor FP8 weight scales before load_state_dict drops them.
# Pre-quantized FP8 checkpoints include weight_scale tensors that are not
# part of the model's parameter schema, so strict=False silently discards
# them. We stash them on the model so that fp8_cast's upcast forward can
# apply them during inference dequantization.
_fp8_weight_scales: dict[str, float] = {}
for key, value in model_state_dict.sd.items():
if key.endswith(".weight_scale") and value.numel() == 1:
base_name = key[: -len(".weight_scale")]
_fp8_weight_scales[base_name] = value.item()
if _fp8_weight_scales:
meta_model._fp8_weight_scales = _fp8_weight_scales # type: ignore[attr-defined]

lora_strengths = [lora.strength for lora in self.loras]
if not lora_strengths or (min(lora_strengths) == 0 and max(lora_strengths) == 0):
sd = model_state_dict.sd
Expand Down
53 changes: 46 additions & 7 deletions packages/ltx-core/src/ltx_core/quantization/fp8_cast.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,27 @@ def _upcast_and_round(
return _fused_add_round_launch(torch.zeros_like(weight, dtype=dtype), weight, seed)


def _replace_fwd_with_upcast(layer: torch.nn.Linear, with_stochastic_rounding: bool = False, seed: int = 0) -> None:
def _replace_fwd_with_upcast(
layer: torch.nn.Linear,
with_stochastic_rounding: bool = False,
seed: int = 0,
weight_scale: float | None = None,
) -> None:
"""
Replace linear.forward and rms_norm.forward with a version that:
Replace linear.forward with a version that:
- upcasts weight and bias to input's dtype
- returns F.linear or F.rms_norm calculated in that dtype
- applies weight_scale if the checkpoint was quantized with per-tensor scaling
- returns F.linear calculated in that dtype

Args:
layer: The Linear layer to patch.
with_stochastic_rounding: Whether to use stochastic rounding during upcast.
seed: Seed for stochastic rounding.
weight_scale: Per-tensor scale factor from the FP8 checkpoint. When provided,
the dequantized weight is multiplied by this value. This is required for
FP8 checkpoints that were quantized with per-tensor scaling (e.g.
``ltx-2.3-22b-dev-fp8.safetensors``) where each weight tensor has an
associated ``weight_scale`` stored alongside it.
"""

layer.original_forward = layer.forward
Expand All @@ -78,6 +94,13 @@ def new_linear_forward(*args, **_kwargs) -> torch.Tensor:
# assume first arg is the input tensor
x = args[0]
w_up = _upcast_and_round(layer.weight, x.dtype, with_stochastic_rounding, seed)

# Apply per-tensor weight scale from FP8 checkpoint if available.
# Without this, pre-quantized FP8 checkpoints produce incorrect outputs
# because the raw FP8 values are not in the correct magnitude range.
if weight_scale is not None:
w_up = w_up * weight_scale

b_up = None

if layer.bias is not None:
Expand All @@ -92,12 +115,28 @@ def _amend_forward_with_upcast(
model: torch.nn.Module, with_stochastic_rounding: bool = False, seed: int = 0
) -> torch.nn.Module:
"""
Replace the forward method of the model's Linear and RMSNorm layers to forward
Replace the forward method of the model's Linear layers to forward
with upcast and optional stochastic rounding.

If the model was loaded from a pre-quantized FP8 checkpoint that includes
per-tensor ``weight_scale`` values (stashed on the model by the builder as
``_fp8_weight_scales``), those scales are automatically applied during the
upcast to produce correctly-scaled outputs.

This is necessary because pre-quantized FP8 checkpoints (e.g.
``ltx-2.3-22b-dev-fp8.safetensors``) store weights in a scaled FP8 format
where the raw FP8 values must be multiplied by their associated
``weight_scale`` to recover the correct magnitude. Without this, a naive
``.to(bfloat16)`` produces values that are orders of magnitude too large,
resulting in noise output.
"""
for m in model.modules():
if isinstance(m, (torch.nn.Linear)):
_replace_fwd_with_upcast(m, with_stochastic_rounding, seed)
# Retrieve per-tensor weight scales stashed by the model builder
weight_scales: dict[str, float] = getattr(model, "_fp8_weight_scales", {})

for name, m in model.named_modules():
if isinstance(m, torch.nn.Linear):
scale = weight_scales.get(name, None)
_replace_fwd_with_upcast(m, with_stochastic_rounding, seed, weight_scale=scale)
return model


Expand Down