Skip to content

Commit 33ce449

Browse files
authored
Reduce LTX2.3 peak VRAM when guide_mask is in use (CORE-166) (Comfy-Org#13735)
- Reduce peak VRAM by handling self_attn_mask more efficiently
- Fall back to SDPA when self_attention_mask is used
1 parent 04856ac commit 33ce449

3 files changed

Lines changed: 106 additions & 89 deletions

File tree

comfy/ldm/lightricks/av_model.py

Lines changed: 35 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -22,26 +22,25 @@ class CompressedTimestep:
2222
"""Store video timestep embeddings in compressed form using per-frame indexing."""
2323
__slots__ = ('data', 'batch_size', 'num_frames', 'patches_per_frame', 'feature_dim')
2424

25-
def __init__(self, tensor: torch.Tensor, patches_per_frame: int):
25+
def __init__(self, tensor: torch.Tensor, patches_per_frame: int, per_frame: bool = False):
2626
"""
27-
tensor: [batch_size, num_tokens, feature_dim] tensor where num_tokens = num_frames * patches_per_frame
28-
patches_per_frame: Number of spatial patches per frame (height * width in latent space), or None to disable compression
27+
tensor: [batch, num_tokens, feature_dim] (per-token, default) or
28+
[batch, num_frames, feature_dim] (per_frame=True, already compressed).
29+
patches_per_frame: spatial patches per frame; pass None to disable compression.
2930
"""
30-
self.batch_size, num_tokens, self.feature_dim = tensor.shape
31-
32-
# Check if compression is valid (num_tokens must be divisible by patches_per_frame)
33-
if patches_per_frame is not None and num_tokens % patches_per_frame == 0 and num_tokens >= patches_per_frame:
31+
self.batch_size, n, self.feature_dim = tensor.shape
32+
if per_frame:
3433
self.patches_per_frame = patches_per_frame
35-
self.num_frames = num_tokens // patches_per_frame
36-
37-
# Reshape to [batch, frames, patches_per_frame, feature_dim] and store one value per frame
38-
# All patches in a frame are identical, so we only keep the first one
39-
reshaped = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)
40-
self.data = reshaped[:, :, 0, :].contiguous() # [batch, frames, feature_dim]
34+
self.num_frames = n
35+
self.data = tensor
36+
elif patches_per_frame is not None and n >= patches_per_frame and n % patches_per_frame == 0:
37+
self.patches_per_frame = patches_per_frame
38+
self.num_frames = n // patches_per_frame
39+
# All patches in a frame are identical — keep only the first.
40+
self.data = tensor.view(self.batch_size, self.num_frames, patches_per_frame, self.feature_dim)[:, :, 0, :].contiguous()
4141
else:
42-
# Not divisible or too small - store directly without compression
4342
self.patches_per_frame = 1
44-
self.num_frames = num_tokens
43+
self.num_frames = n
4544
self.data = tensor
4645

4746
def expand(self):
@@ -716,32 +715,35 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
716715

717716
def _prepare_timestep(self, timestep, batch_size, hidden_dtype, **kwargs):
718717
"""Prepare timestep embeddings."""
719-
# TODO: some code reuse is needed here.
720718
grid_mask = kwargs.get("grid_mask", None)
721-
if grid_mask is not None:
722-
timestep = timestep[:, grid_mask]
719+
orig_shape = kwargs.get("orig_shape")
720+
has_spatial_mask = kwargs.get("has_spatial_mask", None)
721+
v_patches_per_frame = None
722+
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
723+
v_patches_per_frame = orig_shape[3] * orig_shape[4]
723724

724-
timestep_scaled = timestep * self.timestep_scale_multiplier
725+
# Used by compute_prompt_timestep and the audio cross-attention paths.
726+
timestep_scaled = (timestep[:, grid_mask] if grid_mask is not None else timestep) * self.timestep_scale_multiplier
727+
728+
# When patches in a frame share a timestep (no spatial mask), project one row per frame instead of one per token
729+
per_frame_path = v_patches_per_frame is not None and (timestep.numel() // batch_size) % v_patches_per_frame == 0
730+
if per_frame_path:
731+
per_frame = timestep.reshape(batch_size, -1, v_patches_per_frame)[:, :, 0]
732+
if grid_mask is not None:
733+
# All-or-nothing per frame when has_spatial_mask=False.
734+
per_frame = per_frame[:, grid_mask[::v_patches_per_frame]]
735+
ts_input = per_frame * self.timestep_scale_multiplier
736+
else:
737+
ts_input = timestep_scaled
725738

726739
v_timestep, v_embedded_timestep = self.adaln_single(
727-
timestep_scaled.flatten(),
740+
ts_input.flatten(),
728741
{"resolution": None, "aspect_ratio": None},
729742
batch_size=batch_size,
730743
hidden_dtype=hidden_dtype,
731744
)
732-
733-
# Calculate patches_per_frame from orig_shape: [batch, channels, frames, height, width]
734-
# Video tokens are arranged as (frames * height * width), so patches_per_frame = height * width
735-
orig_shape = kwargs.get("orig_shape")
736-
has_spatial_mask = kwargs.get("has_spatial_mask", None)
737-
v_patches_per_frame = None
738-
if not has_spatial_mask and orig_shape is not None and len(orig_shape) == 5:
739-
# orig_shape[3] = height, orig_shape[4] = width (in latent space)
740-
v_patches_per_frame = orig_shape[3] * orig_shape[4]
741-
742-
# Reshape to [batch_size, num_tokens, dim] and compress for storage
743-
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
744-
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
745+
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
746+
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame, per_frame=per_frame_path)
745747

746748
v_prompt_timestep = compute_prompt_timestep(
747749
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype

comfy/ldm/lightricks/model.py

Lines changed: 68 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
358358
return output.swapaxes(1, 2).reshape(B, T, -1) if needs_reshape else output
359359

360360

361+
class GuideAttentionMask:
    """Pair of additive, log-space attention biases for LTXV guide self-attention.

    _attention_with_guide_mask splits the queries into a noisy group and a
    tracked-guide group, so one small bias per group is stored instead of a
    dense (1, 1, T, T) mask:

    * noisy_mask, shape (1, 1, 1, total_tokens): zero everywhere except the
      tracked-guide key columns, which hold log(weight).
    * tracked_mask, shape (1, 1, tracked_count, total_tokens): row i holds
      log(weight[i]) on the key columns before guide_start, zero elsewhere.

    Weights <= 0 map to finfo.min (effectively fully masked); a weight of 1
    maps to 0 (no bias); weights > 1 yield a positive bias.
    """
    __slots__ = ("guide_start", "tracked_count", "noisy_mask", "tracked_mask")

    def __init__(self, total_tokens, guide_start, tracked_count, tracked_weights):
        self.guide_start = guide_start
        self.tracked_count = tracked_count

        # Both masks inherit dtype and device from the weights tensor.
        limits = torch.finfo(tracked_weights.dtype)

        # log(w) where w > 0, finfo.min otherwise. The clamp keeps log() finite
        # for positive values that are below the smallest normal number.
        bias = torch.where(
            tracked_weights > 0,
            torch.log(tracked_weights.clamp(min=limits.tiny)),
            torch.full_like(tracked_weights, limits.min),
        )
        tracked_end = guide_start + tracked_count

        # Noisy queries: one broadcastable row, biased on tracked-guide keys.
        noisy = tracked_weights.new_zeros((1, 1, 1, total_tokens))
        noisy[..., guide_start:tracked_end] = bias.view(1, 1, 1, -1)
        self.noisy_mask = noisy

        # Tracked-guide queries: one row each, biased on the keys before the guides.
        tracked = tracked_weights.new_zeros((1, 1, tracked_count, total_tokens))
        tracked[..., :guide_start] = bias.view(1, 1, -1, 1)
        self.tracked_mask = tracked
386+
387+
def _attention_with_guide_mask(q, k, v, heads, guide_mask, attn_precision, transformer_options):
388+
"""Apply the guide mask by partitioning Q into noisy and tracked-guide
389+
groups, so each group needs only its own sub-mask. Avoids materializing
390+
the (1,1,T,T) dense mask.
391+
"""
392+
guide_start = guide_mask.guide_start
393+
tracked_end = guide_start + guide_mask.tracked_count
394+
395+
out = torch.empty_like(q)
396+
397+
if guide_start > 0: # In practice currently guides are always after noise, guard for safety if this changes.
398+
out[:, :guide_start, :] = comfy.ldm.modules.attention.optimized_attention(
399+
q[:, :guide_start, :], k, v, heads, mask=guide_mask.noisy_mask,
400+
attn_precision=attn_precision, transformer_options=transformer_options,
401+
low_precision_attention=False, # sageattn mask support is unreliable
402+
)
403+
out[:, guide_start:tracked_end, :] = comfy.ldm.modules.attention.optimized_attention(
404+
q[:, guide_start:tracked_end, :], k, v, heads, mask=guide_mask.tracked_mask,
405+
attn_precision=attn_precision, transformer_options=transformer_options,
406+
low_precision_attention=False,
407+
)
408+
if tracked_end < q.shape[1]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
409+
out[:, tracked_end:, :] = comfy.ldm.modules.attention.optimized_attention(
410+
q[:, tracked_end:, :], k, v, heads,
411+
attn_precision=attn_precision, transformer_options=transformer_options,
412+
)
413+
return out
414+
415+
361416
class CrossAttention(nn.Module):
362417
def __init__(
363418
self,
@@ -412,8 +467,10 @@ def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_op
412467

413468
if mask is None:
414469
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
470+
elif isinstance(mask, GuideAttentionMask):
471+
out = _attention_with_guide_mask(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
415472
else:
416-
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
473+
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, mask=mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
417474

418475
# Apply per-head gating if enabled
419476
if self.to_gate_logits is not None:
@@ -1063,7 +1120,9 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
10631120
additional_args["resolved_guide_entries"] = resolved_entries
10641121

10651122
keyframe_idxs = keyframe_idxs[..., kf_grid_mask, :]
1066-
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
1123+
1124+
if keyframe_idxs.shape[2] > 0: # Guard for the case of no keyframes surviving
1125+
pixel_coords[:, :, -keyframe_idxs.shape[2]:, :] = keyframe_idxs
10671126

10681127
# Total surviving guide tokens (all guides)
10691128
additional_args["num_guide_tokens"] = keyframe_idxs.shape[2]
@@ -1099,12 +1158,12 @@ def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
10991158
if not resolved_entries:
11001159
return None
11011160

1102-
# Check if any attenuation is actually needed
1103-
needs_attenuation = any(
1104-
e["strength"] < 1.0 or e.get("pixel_mask") is not None
1161+
# strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
1162+
needs_mask = any(
1163+
e["strength"] != 1.0 or e.get("pixel_mask") is not None
11051164
for e in resolved_entries
11061165
)
1107-
if not needs_attenuation:
1166+
if not needs_mask:
11081167
return None
11091168

11101169
# Build per-guide-token weights for all tracked guide tokens.
@@ -1159,16 +1218,11 @@ def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
11591218
# Concatenate per-token weights for all tracked guides
11601219
tracked_weights = torch.cat(all_weights, dim=1) # (1, total_tracked)
11611220

1162-
# Check if any weight is actually < 1.0 (otherwise no attenuation needed)
1163-
if (tracked_weights >= 1.0).all():
1221+
# Skip when every weight is exactly 1.0 (additive bias would be 0).
1222+
if (tracked_weights == 1.0).all():
11641223
return None
11651224

1166-
# Build the mask: guide tokens are at the end of the sequence.
1167-
# Tracked guides come first (in order), untracked follow.
1168-
return self._build_self_attention_mask(
1169-
total_tokens, num_guide_tokens, total_tracked,
1170-
tracked_weights, guide_start, device, dtype,
1171-
)
1225+
return GuideAttentionMask(total_tokens, guide_start, total_tracked, tracked_weights)
11721226

11731227
@staticmethod
11741228
def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
@@ -1234,45 +1288,6 @@ def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
12341288

12351289
return rearrange(latent_mask, "b 1 f h w -> b (f h w)")
12361290

1237-
@staticmethod
1238-
def _build_self_attention_mask(total_tokens, num_guide_tokens, tracked_count,
1239-
tracked_weights, guide_start, device, dtype):
1240-
"""Build a log-space additive self-attention bias mask.
1241-
1242-
Attenuates attention between noisy tokens and tracked guide tokens.
1243-
Untracked guide tokens (at the end of the guide portion) keep full attention.
1244-
1245-
Args:
1246-
total_tokens: Total sequence length.
1247-
num_guide_tokens: Total guide tokens (all guides) at end of sequence.
1248-
tracked_count: Number of tracked guide tokens (first in the guide portion).
1249-
tracked_weights: (1, tracked_count) tensor, values in [0, 1].
1250-
guide_start: Index where guide tokens begin in the sequence.
1251-
device: Target device.
1252-
dtype: Target dtype.
1253-
1254-
Returns:
1255-
(1, 1, total_tokens, total_tokens) additive bias mask.
1256-
0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
1257-
"""
1258-
finfo = torch.finfo(dtype)
1259-
mask = torch.zeros((1, 1, total_tokens, total_tokens), device=device, dtype=dtype)
1260-
tracked_end = guide_start + tracked_count
1261-
1262-
# Convert weights to log-space bias
1263-
w = tracked_weights.to(device=device, dtype=dtype) # (1, tracked_count)
1264-
log_w = torch.full_like(w, finfo.min)
1265-
positive_mask = w > 0
1266-
if positive_mask.any():
1267-
log_w[positive_mask] = torch.log(w[positive_mask].clamp(min=finfo.tiny))
1268-
1269-
# noisy → tracked guides: each noisy row gets the same per-guide weight
1270-
mask[:, :, :guide_start, guide_start:tracked_end] = log_w.view(1, 1, 1, -1)
1271-
# tracked guides → noisy: each guide row broadcasts its weight across noisy cols
1272-
mask[:, :, guide_start:tracked_end, :guide_start] = log_w.view(1, 1, -1, 1)
1273-
1274-
return mask
1275-
12761291
def _process_transformer_blocks(self, x, context, attention_mask, timestep, pe, transformer_options={}, self_attention_mask=None, **kwargs):
12771292
"""Process transformer blocks for LTXV."""
12781293
patches_replace = transformer_options.get("patches_replace", {})

comfy_extras/nodes_lt.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,7 @@ def define_schema(cls):
219219
"For videos with 9+ frames, frame_idx must be divisible by 8, otherwise it will be rounded "
220220
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
221221
),
222-
io.Float.Input("strength", default=1.0, min=0.0, max=1.0, step=0.01),
222+
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
223223
],
224224
outputs=[
225225
io.Conditioning.Output(display_name="positive"),
@@ -298,7 +298,7 @@ def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask
298298
else:
299299
mask = torch.full(
300300
(noise_mask.shape[0], 1, guiding_latent.shape[2], noise_mask.shape[3], noise_mask.shape[4]),
301-
1.0 - strength,
301+
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
302302
dtype=noise_mask.dtype,
303303
device=noise_mask.device,
304304
)
@@ -318,7 +318,7 @@ def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_
318318

319319
mask = torch.full(
320320
(noise_mask.shape[0], 1, cond_length, 1, 1),
321-
1.0 - strength,
321+
max(0.0, 1.0 - strength), # clamp here to amplify only via the attention mask
322322
dtype=noise_mask.dtype,
323323
device=noise_mask.device,
324324
)

0 commit comments

Comments
 (0)