Skip to content

Commit 292814c

Browse files
authored
feat: Add optional attention_mask input to LTXVAddGuide (CORE-220) (Comfy-Org#13965)
1 parent 187e523 commit 292814c

1 file changed

Lines changed: 12 additions & 5 deletions

File tree

comfy_extras/nodes_lt.py

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -175,7 +175,7 @@ def execute(cls, vae, image, latent, strength, bypass=False) -> io.NodeOutput:
175175
generate = execute # TODO: remove
176176

177177

178-
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0):
178+
def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_shape, strength=1.0, attention_mask=None):
179179
"""Append a guide_attention_entry to both positive and negative conditioning.
180180
181181
Each entry tracks one guide reference for per-reference attention control.
@@ -184,9 +184,10 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
184184
new_entry = {
185185
"pre_filter_count": pre_filter_count,
186186
"strength": strength,
187-
"pixel_mask": None,
187+
"pixel_mask": attention_mask.unsqueeze(0).unsqueeze(0) if attention_mask is not None else None, # reshape to (1, 1, F, H, W)
188188
"latent_shape": latent_shape,
189189
}
190+
190191
results = []
191192
for cond in (positive, negative):
192193
# Read existing entries from this specific conditioning
@@ -196,8 +197,7 @@ def _append_guide_attention_entry(positive, negative, pre_filter_count, latent_s
196197
if found is not None:
197198
existing = found
198199
break
199-
# Shallow copy and append (no deepcopy needed — entries contain
200-
# only scalars and None for pixel_mask at this call site).
200+
# Shallow copy only and append (pixel_mask is never mutated).
201201
entries = [*existing, new_entry]
202202
results.append(node_helpers.conditioning_set_values(
203203
cond, {"guide_attention_entries": entries}
@@ -263,6 +263,12 @@ def define_schema(cls):
263263
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
264264
),
265265
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
266+
io.Mask.Input(
267+
"attention_mask",
268+
optional=True,
269+
tooltip="Optional pixel-space spatial mask. Controls per-region "
270+
"conditioning influence via self-attention, multiplied by strength.",
271+
),
266272
ICLoRAParameters.Input(
267273
"iclora_parameters",
268274
optional=True,
@@ -410,7 +416,7 @@ def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_
410416
return latent_image, noise_mask
411417

412418
@classmethod
413-
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
419+
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, attention_mask=None, iclora_parameters=None) -> io.NodeOutput:
414420
scale_factors = vae.downscale_index_formula
415421
latent_image = latent["samples"]
416422
noise_mask = get_noise_mask(latent)
@@ -469,6 +475,7 @@ def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, ic
469475
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
470476
positive, negative = _append_guide_attention_entry(
471477
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
478+
attention_mask=attention_mask,
472479
)
473480

474481
return io.NodeOutput(positive, negative, {"samples": latent_image, "noise_mask": noise_mask})

0 commit comments

Comments
 (0)