@@ -358,6 +358,61 @@ def apply_split_rotary_emb(input_tensor, cos, sin):
358358 return output .swapaxes (1 , 2 ).reshape (B , T , - 1 ) if needs_reshape else output
359359
360360
361+ class GuideAttentionMask :
362+ """Holds the two per-group masks for LTXV guide self-attention.
363+ _attention_with_guide_mask splits queries into noisy and tracked-guide
364+ groups, so the largest mask is (1, 1, tracked_count, T).
365+ """
366+ __slots__ = ("guide_start" , "tracked_count" , "noisy_mask" , "tracked_mask" )
367+
368+ def __init__ (self , total_tokens , guide_start , tracked_count , tracked_weights ):
369+ device = tracked_weights .device
370+ dtype = tracked_weights .dtype
371+ finfo = torch .finfo (dtype )
372+
373+ pos = tracked_weights > 0
374+ log_w = torch .full_like (tracked_weights , finfo .min )
375+ log_w [pos ] = torch .log (tracked_weights [pos ].clamp (min = finfo .tiny ))
376+
377+ self .guide_start = guide_start
378+ self .tracked_count = tracked_count
379+
380+ self .noisy_mask = torch .zeros ((1 , 1 , 1 , total_tokens ), device = device , dtype = dtype )
381+ self .noisy_mask [:, :, :, guide_start :guide_start + tracked_count ] = log_w .view (1 , 1 , 1 , - 1 )
382+
383+ self .tracked_mask = torch .zeros ((1 , 1 , tracked_count , total_tokens ), device = device , dtype = dtype )
384+ self .tracked_mask [:, :, :, :guide_start ] = log_w .view (1 , 1 , - 1 , 1 )
385+
386+
387+ def _attention_with_guide_mask (q , k , v , heads , guide_mask , attn_precision , transformer_options ):
388+ """Apply the guide mask by partitioning Q into noisy and tracked-guide
389+ groups, so each group needs only its own sub-mask. Avoids materializing
390+ the (1,1,T,T) dense mask.
391+ """
392+ guide_start = guide_mask .guide_start
393+ tracked_end = guide_start + guide_mask .tracked_count
394+
395+ out = torch .empty_like (q )
396+
397+ if guide_start > 0 : # In practice currently guides are always after noise, guard for safety if this changes.
398+ out [:, :guide_start , :] = comfy .ldm .modules .attention .optimized_attention (
399+ q [:, :guide_start , :], k , v , heads , mask = guide_mask .noisy_mask ,
400+ attn_precision = attn_precision , transformer_options = transformer_options ,
401+ low_precision_attention = False , # sageattn mask support is unreliable
402+ )
403+ out [:, guide_start :tracked_end , :] = comfy .ldm .modules .attention .optimized_attention (
404+ q [:, guide_start :tracked_end , :], k , v , heads , mask = guide_mask .tracked_mask ,
405+ attn_precision = attn_precision , transformer_options = transformer_options ,
406+ low_precision_attention = False ,
407+ )
408+ if tracked_end < q .shape [1 ]: # Every guide token is tracked, and nothing comes after them, guard for safety if this changes.
409+ out [:, tracked_end :, :] = comfy .ldm .modules .attention .optimized_attention (
410+ q [:, tracked_end :, :], k , v , heads ,
411+ attn_precision = attn_precision , transformer_options = transformer_options ,
412+ )
413+ return out
414+
415+
361416class CrossAttention (nn .Module ):
362417 def __init__ (
363418 self ,
@@ -412,8 +467,10 @@ def forward(self, x, context=None, mask=None, pe=None, k_pe=None, transformer_op
412467
413468 if mask is None :
414469 out = comfy .ldm .modules .attention .optimized_attention (q , k , v , self .heads , attn_precision = self .attn_precision , transformer_options = transformer_options )
470+ elif isinstance (mask , GuideAttentionMask ):
471+ out = _attention_with_guide_mask (q , k , v , self .heads , mask , attn_precision = self .attn_precision , transformer_options = transformer_options )
415472 else :
416- out = comfy .ldm .modules .attention .optimized_attention_masked (q , k , v , self .heads , mask , attn_precision = self .attn_precision , transformer_options = transformer_options )
473+ out = comfy .ldm .modules .attention .optimized_attention (q , k , v , self .heads , mask = mask , attn_precision = self .attn_precision , transformer_options = transformer_options )
417474
418475 # Apply per-head gating if enabled
419476 if self .to_gate_logits is not None :
@@ -1063,7 +1120,9 @@ def _process_input(self, x, keyframe_idxs, denoise_mask, **kwargs):
10631120 additional_args ["resolved_guide_entries" ] = resolved_entries
10641121
10651122 keyframe_idxs = keyframe_idxs [..., kf_grid_mask , :]
1066- pixel_coords [:, :, - keyframe_idxs .shape [2 ]:, :] = keyframe_idxs
1123+
1124+ if keyframe_idxs .shape [2 ] > 0 : # Guard for the case of no keyframes surviving
1125+ pixel_coords [:, :, - keyframe_idxs .shape [2 ]:, :] = keyframe_idxs
10671126
10681127 # Total surviving guide tokens (all guides)
10691128 additional_args ["num_guide_tokens" ] = keyframe_idxs .shape [2 ]
@@ -1099,12 +1158,12 @@ def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
10991158 if not resolved_entries :
11001159 return None
11011160
1102- # Check if any attenuation is actually needed
1103- needs_attenuation = any (
1104- e ["strength" ] < 1.0 or e .get ("pixel_mask" ) is not None
1161+ # strength != 1.0 means we want to either attenuate (< 1) or amplify (> 1) guide attention.
1162+ needs_mask = any (
1163+ e ["strength" ] != 1.0 or e .get ("pixel_mask" ) is not None
11051164 for e in resolved_entries
11061165 )
1107- if not needs_attenuation :
1166+ if not needs_mask :
11081167 return None
11091168
11101169 # Build per-guide-token weights for all tracked guide tokens.
@@ -1159,16 +1218,11 @@ def _build_guide_self_attention_mask(self, x, transformer_options, merged_args):
11591218 # Concatenate per-token weights for all tracked guides
11601219 tracked_weights = torch .cat (all_weights , dim = 1 ) # (1, total_tracked)
11611220
1162- # Check if any weight is actually < 1.0 (otherwise no attenuation needed)
1163- if (tracked_weights > = 1.0 ).all ():
1221+ # Skip when every weight is exactly 1.0 (additive bias would be 0).
1222+ if (tracked_weights = = 1.0 ).all ():
11641223 return None
11651224
1166- # Build the mask: guide tokens are at the end of the sequence.
1167- # Tracked guides come first (in order), untracked follow.
1168- return self ._build_self_attention_mask (
1169- total_tokens , num_guide_tokens , total_tracked ,
1170- tracked_weights , guide_start , device , dtype ,
1171- )
1225+ return GuideAttentionMask (total_tokens , guide_start , total_tracked , tracked_weights )
11721226
11731227 @staticmethod
11741228 def _downsample_mask_to_latent (mask , f_lat , h_lat , w_lat ):
@@ -1234,45 +1288,6 @@ def _downsample_mask_to_latent(mask, f_lat, h_lat, w_lat):
12341288
12351289 return rearrange (latent_mask , "b 1 f h w -> b (f h w)" )
12361290
1237- @staticmethod
1238- def _build_self_attention_mask (total_tokens , num_guide_tokens , tracked_count ,
1239- tracked_weights , guide_start , device , dtype ):
1240- """Build a log-space additive self-attention bias mask.
1241-
1242- Attenuates attention between noisy tokens and tracked guide tokens.
1243- Untracked guide tokens (at the end of the guide portion) keep full attention.
1244-
1245- Args:
1246- total_tokens: Total sequence length.
1247- num_guide_tokens: Total guide tokens (all guides) at end of sequence.
1248- tracked_count: Number of tracked guide tokens (first in the guide portion).
1249- tracked_weights: (1, tracked_count) tensor, values in [0, 1].
1250- guide_start: Index where guide tokens begin in the sequence.
1251- device: Target device.
1252- dtype: Target dtype.
1253-
1254- Returns:
1255- (1, 1, total_tokens, total_tokens) additive bias mask.
1256- 0.0 = full attention, negative = attenuated, finfo.min = effectively fully masked.
1257- """
1258- finfo = torch .finfo (dtype )
1259- mask = torch .zeros ((1 , 1 , total_tokens , total_tokens ), device = device , dtype = dtype )
1260- tracked_end = guide_start + tracked_count
1261-
1262- # Convert weights to log-space bias
1263- w = tracked_weights .to (device = device , dtype = dtype ) # (1, tracked_count)
1264- log_w = torch .full_like (w , finfo .min )
1265- positive_mask = w > 0
1266- if positive_mask .any ():
1267- log_w [positive_mask ] = torch .log (w [positive_mask ].clamp (min = finfo .tiny ))
1268-
1269- # noisy → tracked guides: each noisy row gets the same per-guide weight
1270- mask [:, :, :guide_start , guide_start :tracked_end ] = log_w .view (1 , 1 , 1 , - 1 )
1271- # tracked guides → noisy: each guide row broadcasts its weight across noisy cols
1272- mask [:, :, guide_start :tracked_end , :guide_start ] = log_w .view (1 , 1 , - 1 , 1 )
1273-
1274- return mask
1275-
12761291 def _process_transformer_blocks (self , x , context , attention_mask , timestep , pe , transformer_options = {}, self_attention_mask = None , ** kwargs ):
12771292 """Process transformer blocks for LTXV."""
12781293 patches_replace = transformer_options .get ("patches_replace" , {})
0 commit comments