4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -615,7 +615,9 @@
"LongCatImageEditPipeline",
"LongCatImagePipeline",
"LTX2ConditionPipeline",
"LTX2HDRPipeline",
"LTX2ImageToVideoPipeline",
"LTX2InContextPipeline",
"LTX2LatentUpsamplePipeline",
"LTX2Pipeline",
"LTXConditionPipeline",
@@ -1407,7 +1409,9 @@
LongCatImageEditPipeline,
LongCatImagePipeline,
LTX2ConditionPipeline,
LTX2HDRPipeline,
LTX2ImageToVideoPipeline,
LTX2InContextPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
LTXConditionPipeline,
15 changes: 13 additions & 2 deletions src/diffusers/models/transformers/transformer_ltx2.py
@@ -1343,6 +1343,7 @@ def forward(
perturbation_mask: torch.Tensor | None = None,
use_cross_timestep: bool = False,
attention_kwargs: dict[str, Any] | None = None,
video_self_attention_mask: torch.Tensor | None = None,
return_dict: bool = True,
) -> torch.Tensor:
"""
@@ -1408,6 +1409,11 @@
`False` is the legacy LTX-2.0 behavior.
attention_kwargs (`dict[str, Any]`, *optional*):
Optional dict of keyword args to be passed to the attention processor.
video_self_attention_mask (`torch.Tensor`, *optional*):
Optional multiplicative self-attention mask of shape `(batch_size, num_video_tokens, num_video_tokens)`
applied to the video self-attention in each transformer block. Values lie in `[0, 1]`, where `1` means full
attention and `0` means masked. Used, for example, by the IC-LoRA pipeline to control attention strength
between noisy tokens and appended reference tokens. Audio self-attention is not affected.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a dict-like structured output of type `AudioVisualModelOutput` or a tuple.

@@ -1430,6 +1436,11 @@
audio_encoder_attention_mask = (1 - audio_encoder_attention_mask.to(audio_hidden_states.dtype)) * -10000.0
audio_encoder_attention_mask = audio_encoder_attention_mask.unsqueeze(1)

# Convert video_self_attention_mask from multiplicative mask ([0, 1]) to additive bias form (0 / -10000)
# matching the encoder_attention_mask convention above. Shape is preserved: (B, T_v, T_v).
if video_self_attention_mask is not None:
video_self_attention_mask = (1 - video_self_attention_mask.to(hidden_states.dtype)) * -10000.0

batch_size = hidden_states.size(0)

# 1. Prepare RoPE positional embeddings
@@ -1569,7 +1580,7 @@
audio_cross_attn_rotary_emb,
encoder_attention_mask,
audio_encoder_attention_mask,
None, # self_attention_mask
video_self_attention_mask, # self_attention_mask (video-only)
None, # audio_self_attention_mask
None, # a2v_cross_attention_mask
None, # v2a_cross_attention_mask
@@ -1598,7 +1609,7 @@
ca_audio_rotary_emb=audio_cross_attn_rotary_emb,
encoder_attention_mask=encoder_attention_mask,
audio_encoder_attention_mask=audio_encoder_attention_mask,
self_attention_mask=None,
self_attention_mask=video_self_attention_mask,
audio_self_attention_mask=None,
a2v_cross_attention_mask=None,
v2a_cross_attention_mask=None,
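A reviewer-facing sketch of the new mask plumbing may help here: the pipeline supplies a multiplicative mask in `[0, 1]`, and the transformer converts it to an additive bias before it reaches the attention processors. The token counts and the `0.9` strength below are hypothetical illustration values, not taken from the IC-LoRA pipeline.

```python
import torch

# Hypothetical token counts and attention strength, for illustration only.
num_noisy, num_reference = 6, 2
n = num_noisy + num_reference

# Multiplicative mask in [0, 1]: 1 = full attention, 0 = fully masked.
mask = torch.ones(1, n, n)
# Down-weight attention from noisy tokens to the appended reference tokens.
mask[:, :num_noisy, num_noisy:] = 0.9

# The transformer's conversion (see the hunk above): 1 -> 0 bias, 0 -> -10000,
# and fractional values become soft penalties added to the attention logits.
bias = (1 - mask) * -10000.0
```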
11 changes: 10 additions & 1 deletion src/diffusers/pipelines/__init__.py
@@ -330,6 +330,8 @@
_import_structure["ltx2"] = [
"LTX2Pipeline",
"LTX2ConditionPipeline",
"LTX2HDRPipeline",
"LTX2InContextPipeline",
"LTX2ImageToVideoPipeline",
"LTX2LatentUpsamplePipeline",
]
@@ -780,7 +782,14 @@
LTXLatentUpsamplePipeline,
LTXPipeline,
)
from .ltx2 import LTX2ConditionPipeline, LTX2ImageToVideoPipeline, LTX2LatentUpsamplePipeline, LTX2Pipeline
from .ltx2 import (
LTX2ConditionPipeline,
LTX2HDRPipeline,
LTX2ImageToVideoPipeline,
LTX2InContextPipeline,
LTX2LatentUpsamplePipeline,
LTX2Pipeline,
)
from .lucy import LucyEditPipeline
from .lumina import LuminaPipeline, LuminaText2ImgPipeline
from .lumina2 import Lumina2Pipeline, Lumina2Text2ImgPipeline
6 changes: 6 additions & 0 deletions src/diffusers/pipelines/ltx2/__init__.py
@@ -23,9 +23,12 @@
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["connectors"] = ["LTX2TextConnectors"]
_import_structure["image_processor"] = ["LTX2VideoHDRProcessor"]
_import_structure["latent_upsampler"] = ["LTX2LatentUpsamplerModel"]
_import_structure["pipeline_ltx2"] = ["LTX2Pipeline"]
_import_structure["pipeline_ltx2_condition"] = ["LTX2ConditionPipeline"]
_import_structure["pipeline_ltx2_hdr_lora"] = ["LTX2HDRPipeline", "LTX2HDRReferenceCondition"]
_import_structure["pipeline_ltx2_ic_lora"] = ["LTX2InContextPipeline", "LTX2ReferenceCondition"]
_import_structure["pipeline_ltx2_image2video"] = ["LTX2ImageToVideoPipeline"]
_import_structure["pipeline_ltx2_latent_upsample"] = ["LTX2LatentUpsamplePipeline"]
_import_structure["vocoder"] = ["LTX2Vocoder", "LTX2VocoderWithBWE"]
@@ -39,9 +42,12 @@
from ...utils.dummy_torch_and_transformers_objects import *
else:
from .connectors import LTX2TextConnectors
from .image_processor import LTX2VideoHDRProcessor
from .latent_upsampler import LTX2LatentUpsamplerModel
from .pipeline_ltx2 import LTX2Pipeline
from .pipeline_ltx2_condition import LTX2ConditionPipeline
from .pipeline_ltx2_hdr_lora import LTX2HDRPipeline, LTX2HDRReferenceCondition
from .pipeline_ltx2_ic_lora import LTX2InContextPipeline, LTX2ReferenceCondition
from .pipeline_ltx2_image2video import LTX2ImageToVideoPipeline
from .pipeline_ltx2_latent_upsample import LTX2LatentUpsamplePipeline
from .vocoder import LTX2Vocoder, LTX2VocoderWithBWE
88 changes: 88 additions & 0 deletions src/diffusers/pipelines/ltx2/export_utils.py
@@ -16,6 +16,8 @@
from collections.abc import Iterator
from fractions import Fraction
from itertools import chain
from pathlib import Path
from typing import Callable

import numpy as np
import PIL.Image
@@ -189,3 +191,89 @@ def encode_video(
_write_audio(container, audio_stream, audio, audio_sample_rate)

container.close()


def simple_tone_map(x: np.ndarray) -> np.ndarray:
r"""
Applies a trivial tone-mapping to (scene-referred) linear light: values above `1.0` are clipped to `1.0`. This
matches the original LTX-2.X code, but a more sophisticated tone-mapping operator will usually produce
better-looking results.
"""
return np.clip(x, 0.0, 1.0)


# Adapted from ltx_pipelines.utils.media_io._linear_to_srgb
# https://github.com/Lightricks/LTX-2/blob/41d924371612b692c0fd1e4d9d94c3dfb3c02cb3/packages/ltx-pipelines/src/ltx_pipelines/utils/media_io.py#L644
def linear_to_srgb(x: np.ndarray) -> np.ndarray:
r"""
Apply the sRGB transfer function (OETF; IEC 61966-2-1) to a linear-light image. Input values must be in
`[0, 1]`.
"""
return np.where(x <= 0.0031308, x * 12.92, 1.055 * np.power(x, 1.0 / 2.4) - 0.055)


def encode_hdr_tensor_to_mp4(
frames: torch.Tensor,
output_mp4: str | Path,
frame_rate: float,
tone_mapping_fn: Callable[[np.ndarray], np.ndarray] | None = None,
tone_map_in_rgb: bool = True,
crf: int = 18,
) -> None:
"""
Converts a linear HDR tensor (for example, as output by `LTX2HDRPipeline`) to an SDR `.mp4` file (specifically,
a tone-mapped, sRGB-encoded H.264 `.mp4`).

Args:
frames (`torch.Tensor`):
A linear HDR tensor of shape `(F, H, W, 3)` with RGB values in `[0, ∞)`.
output_mp4 (`str` or `pathlib.Path`):
Output MP4 path.
frame_rate (`float`):
Frame rate for the output video.
tone_mapping_fn (`Callable[[np.ndarray], np.ndarray]`, *optional*, defaults to `None`):
An optional tone mapping function which takes a float32 NumPy array of shape `(H, W, 3)` containing linear
HDR values in `[0, ∞)` and returns tone-mapped linear values in `[0, 1]`. The sRGB transfer function (OETF)
is applied afterwards; do **not** pre-apply gamma inside this function. If `None`, defaults to
[`simple_tone_map`], which clips values above `1.0`. The channel ordering of the input array is controlled
by `tone_map_in_rgb`: RGB by default (matching the `LTX2HDRPipeline` output), or BGR when
`tone_map_in_rgb=False`.
tone_map_in_rgb (`bool`, *optional*, defaults to `True`):
When `True` (default), frames are passed as RGB to `tone_mapping_fn`, and the output frame is tagged as
`rgb24`. Use this when `tone_mapping_fn` expects RGB input (e.g. operators from `colour-science`). When
`False`, the frames first have their channels flipped to BGR, which is the native format for
`opencv-python` tone mappers (e.g. `cv2.createTonemapReinhard().process`). Note that this is the opposite
default to `encode_exr_sequence_to_mp4`.
crf (`int`, *optional*, defaults to `18`):
libx264 CRF quality factor. Lower values produce higher quality.
"""
frames_np = frames.cpu().float().numpy()

container = av.open(str(output_mp4), mode="w")
stream = container.add_stream("libx264", rate=Fraction(frame_rate).limit_denominator(1000))
stream.pix_fmt = "yuv420p"
stream.options = {"crf": str(crf), "movflags": "+faststart"}

pix_fmt = "rgb24" if tone_map_in_rgb else "bgr24"
if tone_mapping_fn is None:
tone_mapping_fn = simple_tone_map

try:
for i, hdr in enumerate(frames_np):
if not tone_map_in_rgb:
hdr = hdr[..., ::-1]
hdr_mapped = tone_mapping_fn(hdr)
sdr = linear_to_srgb(np.maximum(hdr_mapped, 0.0))
out8 = (sdr * 255.0 + 0.5).astype(np.uint8)

if i == 0:
stream.height, stream.width = out8.shape[:2]

frame = av.VideoFrame.from_ndarray(out8, format=pix_fmt)
for packet in stream.encode(frame):
container.mux(packet)

for packet in stream.encode():
container.mux(packet)
finally:
container.close()
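A usage sketch for the new helper (not part of the diff): the random `frames` tensor stands in for real `LTX2HDRPipeline` output, and the global Reinhard operator `x / (1 + x)` is just one example of a non-trivial tone map.

```python
import numpy as np
import torch

from diffusers.pipelines.ltx2.export_utils import encode_hdr_tensor_to_mp4


def reinhard_tone_map(x: np.ndarray) -> np.ndarray:
    # Global Reinhard operator: maps linear [0, inf) into [0, 1).
    return x / (1.0 + x)


# Stand-in for LTX2HDRPipeline output: (F, H, W, 3), linear HDR in [0, inf).
frames = torch.rand(16, 256, 384, 3) * 4.0
encode_hdr_tensor_to_mp4(
    frames,
    "sample_hdr_tonemapped.mp4",
    frame_rate=24.0,
    tone_mapping_fn=reinhard_tone_map,  # receives RGB (tone_map_in_rgb=True)
    crf=18,
)
```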
150 changes: 150 additions & 0 deletions src/diffusers/pipelines/ltx2/image_processor.py
@@ -0,0 +1,150 @@
# Copyright 2025 Lightricks and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import torch.nn.functional as F

from ...configuration_utils import register_to_config
from ...video_processor import VideoProcessor


class LTX2VideoHDRProcessor(VideoProcessor):
r"""
Video processor for the LTX-2 HDR IC-LoRA pipeline.

Inherits standard video preprocessing from [`VideoProcessor`] and additionally supports:

- `preprocess_reference_video_hdr`: aspect-ratio-preserving resize followed by reflect-padding to the target size.
For LDR (SDR Rec.709) reference videos, `LogC3.compress_ldr` is an identity clamp, so the numerical output is
equivalent to the standard `[-1, 1]` normalization used by [`VideoProcessor.preprocess_video`]; only the resize
strategy differs (reflect-pad vs. center-crop).
- `postprocess_hdr_video`: applies the LogC3 inverse transform to the VAE's decoded output, mapping `[0, 1]` →
linear HDR `[0, ∞)`.

Args:
vae_scale_factor (`int`, *optional*, defaults to `32`):
VAE (spatial) scale factor for the LTX-2 video VAE.
resample (`str`, *optional*, defaults to `"bilinear"`):
Resampling filter used by the base [`VaeImageProcessor`] for PIL/tensor resizing.
hdr_transform (`str`, *optional*, defaults to `"logc3"`):
HDR transform identifier. Only `"logc3"` (ARRI EI 800) is currently supported.
"""

# LogC3 (ARRI EI 800) coefficients, ported from `ltx_core.hdr.LogC3`.
_LOGC3_A = 5.555556
_LOGC3_B = 0.052272
_LOGC3_C = 0.247190
_LOGC3_D = 0.385537
_LOGC3_E = 5.367655
_LOGC3_F = 0.092809
_LOGC3_CUT = 0.010591

@register_to_config
def __init__(
self,
vae_scale_factor: int = 32,
resample: str = "bilinear",
hdr_transform: str = "logc3",
):
super().__init__(
do_resize=True,
vae_scale_factor=vae_scale_factor,
resample=resample,
)
if hdr_transform != "logc3":
raise ValueError(f"Unsupported HDR transform {hdr_transform!r}. Only 'logc3' is supported.")

@classmethod
def _logc3_decompress(cls, logc: torch.Tensor) -> torch.Tensor:
r"""Decompress LogC3 `[0, 1]` → linear HDR `[0, ∞)`."""
logc = torch.clamp(logc, 0.0, 1.0)
cut_log = cls._LOGC3_E * cls._LOGC3_CUT + cls._LOGC3_F
lin_from_log = (torch.pow(10.0, (logc - cls._LOGC3_D) / cls._LOGC3_C) - cls._LOGC3_B) / cls._LOGC3_A
lin_from_lin = (logc - cls._LOGC3_F) / cls._LOGC3_E
return torch.where(logc >= cut_log, lin_from_log, lin_from_lin)

@staticmethod
def _resize_and_reflect_pad_video(video: torch.Tensor, height: int, width: int) -> torch.Tensor:
r"""
Resize a video tensor preserving aspect ratio, then reflect-pad to the exact target dimensions.

Args:
video (`torch.Tensor`): Input of shape `(B, C, F, H, W)`.
height (`int`), width (`int`): Target spatial dimensions.

Returns:
`torch.Tensor`: Resized and padded video of shape `(B, C, F, height, width)`.
"""
b, c, f, src_h, src_w = video.shape

if height >= src_h and width >= src_w:
new_h, new_w = src_h, src_w
else:
scale = min(height / src_h, width / src_w)
new_h = round(src_h * scale)
new_w = round(src_w * scale)
# (B, C, F, H, W) → (B, F, C, H, W) → (B*F, C, H, W) for 2D per-frame interpolation.
video = video.permute(0, 2, 1, 3, 4).reshape(b * f, c, src_h, src_w)
video = F.interpolate(video, size=(new_h, new_w), mode="bilinear", align_corners=False)
video = video.reshape(b, f, c, new_h, new_w).permute(0, 2, 1, 3, 4)

pad_bottom = height - new_h
pad_right = width - new_w
if pad_bottom > 0 or pad_right > 0:
# `reflect` pad requires the pad amount to be strictly less than the corresponding input dim.
pad_mode = "reflect" if pad_bottom < new_h and pad_right < new_w else "replicate"
video = video.permute(0, 2, 1, 3, 4).reshape(b * f, c, new_h, new_w)
video = F.pad(video, (0, pad_right, 0, pad_bottom), mode=pad_mode)
video = video.reshape(b, f, c, height, width).permute(0, 2, 1, 3, 4)

return video

def preprocess_reference_video_hdr(
self,
video,
height: int,
width: int,
) -> torch.Tensor:
r"""
Preprocess a reference (SDR) video for HDR IC-LoRA conditioning.

Runs the input through the standard video preprocessing (normalization to `[-1, 1]`) without resizing, then
applies reflect-pad resize to the target dimensions. For LDR inputs this is numerically equivalent to
`load_video_conditioning_hdr` in the reference implementation (since `LogC3.compress_ldr` is an identity clamp
on `[0, 1]` inputs).

Args:
video: Input accepted by `VideoProcessor.preprocess_video` (list of PIL images, 4D/5D tensor/array, etc.).
height (`int`), width (`int`): Target spatial dimensions.

Returns:
`torch.Tensor`: Preprocessed video of shape `(B, C, F, height, width)` with values in `[-1, 1]`.
"""
video = self.preprocess_video(video, height=None, width=None) # (B, C, F, src_h, src_w) in [-1, 1]
video = self._resize_and_reflect_pad_video(video, height, width)
return video

def postprocess_hdr_video(self, video: torch.Tensor) -> torch.Tensor:
r"""
Postprocess the VAE's decoded output to linear HDR.

Args:
video (`torch.Tensor`):
VAE decoded output in VAE range `[-1, 1]`, shape `(B, C, F, H, W)`.

Returns:
`torch.Tensor`: Linear HDR video `[0, ∞)`, shape `(B, C, F, H, W)`, dtype `float32`.
"""
video = (video.float() / 2.0 + 0.5).clamp(0.0, 1.0)
return self._logc3_decompress(video)
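A quick shape-and-range sketch of the new processor (illustration only; the sizes are arbitrary and the "decoded" tensor below is random rather than a real VAE output):

```python
import torch

from diffusers.pipelines.ltx2.image_processor import LTX2VideoHDRProcessor

processor = LTX2VideoHDRProcessor(vae_scale_factor=32)

# Aspect-ratio-preserving resize + reflect pad: a 360x640 SDR reference clip is
# fitted into a 512x768 target without center-cropping.
reference = torch.rand(1, 9, 3, 360, 640)  # (B, F, C, H, W), values in [0, 1]
conditioned = processor.preprocess_reference_video_hdr(reference, height=512, width=768)
assert conditioned.shape == (1, 3, 9, 512, 768)  # normalized to [-1, 1]

# LogC3 decode of the stand-in VAE output: [-1, 1] -> linear HDR.
decoded = torch.rand(1, 3, 9, 512, 768) * 2.0 - 1.0
hdr = processor.postprocess_hdr_video(decoded)
assert hdr.shape == decoded.shape and hdr.dtype == torch.float32
```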