Skip to content

Commit d3607a8

Browse files
authored
feat: Add downscaled IC-LoRA support to LTXVAddGuide (CORE-102) (Comfy-Org#13896)
1 parent 5d5a455 commit d3607a8

3 files changed

Lines changed: 108 additions & 9 deletions

File tree

comfy/sd.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@
7979

8080
import comfy.ldm.flux.redux
8181

82-
def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
82+
def load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=None):
8383
key_map = {}
8484
if model is not None:
8585
key_map = comfy.lora.model_lora_keys_unet(model.model, key_map)
@@ -91,13 +91,17 @@ def load_lora_for_models(model, clip, lora, strength_model, strength_clip):
9191
if model is not None:
9292
new_modelpatcher = model.clone()
9393
k = new_modelpatcher.add_patches(loaded, strength_model)
94+
if lora_metadata:
95+
new_modelpatcher.set_attachments("lora_metadata", lora_metadata)
9496
else:
9597
k = ()
9698
new_modelpatcher = None
9799

98100
if clip is not None:
99101
new_clip = clip.clone()
100102
k1 = new_clip.add_patches(loaded, strength_clip)
103+
if lora_metadata:
104+
new_clip.patcher.set_attachments("lora_metadata", lora_metadata)
101105
else:
102106
k1 = ()
103107
new_clip = None

comfy_extras/nodes_lt.py

Lines changed: 98 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,49 @@
1414
from comfy.ldm.lightricks.symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords
1515
from comfy_api.latest import ComfyExtension, io
1616

17+
# Custom socket type (identifier "IC_LORA_PARAMETERS") that carries IC-LoRA
# parameters between nodes: it is the output type of GetICLoRAParameters and the
# optional "iclora_parameters" input type of LTXVAddGuide.
ICLoRAParameters = io.Custom("IC_LORA_PARAMETERS")
18+
19+
20+
class GetICLoRAParameters(io.ComfyNode):
    """Node that turns the LoRA metadata attached to a model into IC-LoRA parameters.

    The node reads the ``"lora_metadata"`` attachment of the incoming model and
    emits a dict with a single ``"reference_downscale_factor"`` entry, intended
    to be connected to the ``iclora_parameters`` input of LTXVAddGuide.
    """

    @classmethod
    def define_schema(cls):
        """Describe the node: one model input, one IC-LoRA parameters output."""
        return io.Schema(
            node_id="GetICLoRAParameters",
            display_name="Get IC-LoRA Parameters",
            description="Extracts IC-LoRA parameters from the safetensors metadata of a LoRA-loaded "
                        "model and outputs them for LTXVAddGuide (eg. reference_downscale_factor).",
            category="conditioning/video_models",
            search_aliases=["ic-lora", "ic lora", "iclora", "downscale factor", "reference downscale"],
            inputs=[
                io.Model.Input(
                    "iclora_model",
                    tooltip="Direct output from a LoRA Loader for the specific IC-LoRA "
                            "from which to extract the metadata.",
                ),
            ],
            outputs=[
                ICLoRAParameters.Output(
                    "iclora_parameters",
                    tooltip="IC-LoRA parameters extracted from the LoRA metadata "
                            "(eg. reference_downscale_factor). Connect to LTXVAddGuide "
                            "if the LoRA requires special handling of the guides.",
                ),
            ],
        )

    @classmethod
    def execute(cls, iclora_model) -> io.NodeOutput:
        """Build the IC-LoRA parameters dict from the model's LoRA metadata.

        Args:
            iclora_model: Model whose ``"lora_metadata"`` attachment is read.

        Returns:
            A node output wrapping ``{"reference_downscale_factor": factor}``,
            where ``factor`` is an integer >= 1. The factor is 1 when the
            metadata is missing/empty, lacks the key, or holds a value that
            cannot be converted to a finite number.
        """
        metadata = iclora_model.get_attachment("lora_metadata")
        factor = 1
        if metadata:
            try:
                # Metadata values may be strings; parse, round to the nearest
                # integer and never go below 1.
                factor = max(1, round(float(metadata.get("reference_downscale_factor", 1))))
            except (TypeError, ValueError, OverflowError):
                # TypeError/ValueError: unparsable or NaN value.
                # OverflowError: round() of an infinite value ("inf").
                factor = 1
        parameters = {"reference_downscale_factor": factor}
        return io.NodeOutput(parameters)
58+
59+
1760
class EmptyLTXVLatentVideo(io.ComfyNode):
1861
@classmethod
1962
def define_schema(cls):
@@ -220,6 +263,14 @@ def define_schema(cls):
220263
"down to the nearest multiple of 8. Negative values are counted from the end of the video.",
221264
),
222265
io.Float.Input("strength", default=1.0, min=0.0, max=10.0, step=0.01),
266+
ICLoRAParameters.Input(
267+
"iclora_parameters",
268+
optional=True,
269+
tooltip="Optional IC-LoRA parameters from a Get IC-LoRA Parameters node. "
270+
"Used for adjusting guide processing as required by certain IC-LoRAs "
271+
"(eg. those with a reference_downscale_factor > 1). "
272+
"When chained, each LTXVAddGuide uses only the parameters connected to it.",
273+
),
223274
],
224275
outputs=[
225276
io.Conditioning.Output(display_name="positive"),
@@ -229,14 +280,41 @@ def define_schema(cls):
229280
)
230281

231282
@classmethod
def encode(cls, vae, latent_width, latent_height, images, scale_factors, latent_downscale_factor=1):
    """Resize guide images to the (optionally downscaled) pixel size and VAE-encode them.

    Args:
        vae: Object whose ``encode`` method is called on the resized pixels.
        latent_width: Width of the target latent.
        latent_height: Height of the target latent.
        images: Guide frames; the frame axis is first and the channel axis is last.
        scale_factors: ``(time, width, height)`` scale factors.
        latent_downscale_factor: Divisor applied to the target pixel size (default 1).

    Returns:
        Tuple ``(encode_pixels, t)``: the resized pixels limited to the first
        three channels, and the result of ``vae.encode`` on them.
    """
    time_scale_factor, width_scale_factor, height_scale_factor = scale_factors

    # Keep the longest prefix whose length has the form k * time_scale_factor + 1.
    usable_frames = (images.shape[0] - 1) // time_scale_factor * time_scale_factor + 1
    images = images[:usable_frames]

    target_width = int(latent_width * width_scale_factor / latent_downscale_factor)
    target_height = int(latent_height * height_scale_factor / latent_downscale_factor)

    # common_upscale is fed channels-first data; move channels back to last afterwards.
    channels_first = images.movedim(-1, 1)
    resized = comfy.utils.common_upscale(channels_first, target_width, target_height, "bilinear", crop="center")
    encode_pixels = resized.movedim(1, -1)[:, :, :, :3]

    t = vae.encode(encode_pixels)
    return encode_pixels, t
239292

293+
@classmethod
def dilate_latent(cls, guide_latent, latent_downscale_factor):
    """Spread a guide latent onto a spatially larger grid with zero-filled gaps.

    Args:
        guide_latent: 5-D tensor whose last two axes are the spatial axes.
        latent_downscale_factor: Spacing between the copied values; values
            <= 1 leave the latent untouched.

    Returns:
        ``(guide_latent, None)`` when the factor is <= 1. Otherwise
        ``(dilated, dilated_mask)``: ``dilated`` has both spatial axes
        multiplied by the factor, holds the original values at every
        ``factor``-th position and zeros elsewhere; ``dilated_mask`` has one
        channel and is 1.0 at the filled positions and -1.0 elsewhere.
    """
    if latent_downscale_factor <= 1:
        return guide_latent, None

    stride = int(latent_downscale_factor)
    out_height = guide_latent.shape[3] * stride
    out_width = guide_latent.shape[4] * stride

    # new_zeros / new_full inherit the device and dtype of guide_latent.
    dilated = guide_latent.new_zeros(tuple(guide_latent.shape[:3]) + (out_height, out_width))
    dilated[..., ::stride, ::stride] = guide_latent

    mask_shape = (dilated.shape[0], 1) + tuple(dilated.shape[2:])
    dilated_mask = guide_latent.new_full(mask_shape, -1.0)
    dilated_mask[..., ::stride, ::stride] = 1.0

    return dilated, dilated_mask
307+
308+
@classmethod
def get_reference_downscale_factor(cls, iclora_parameters):
    """Read the integer reference downscale factor from IC-LoRA parameters.

    Args:
        iclora_parameters: Mapping that may contain a
            ``"reference_downscale_factor"`` entry, or a falsy value
            (``None``, empty dict) when nothing is connected.

    Returns:
        An integer >= 1. Falls back to 1 when the parameters are missing, the
        key is absent, or the value cannot be converted to a finite number.
    """
    if not iclora_parameters:
        return 1
    try:
        factor = max(1, round(float(iclora_parameters.get("reference_downscale_factor", 1))))
    except (TypeError, ValueError, OverflowError):
        # TypeError/ValueError: unparsable or NaN value.
        # OverflowError: round() of an infinite value ("inf").
        factor = 1
    return factor
317+
240318
@classmethod
241319
def get_latent_index(cls, cond, latent_length, guide_length, frame_idx, scale_factors):
242320
time_scale_factor, _, _ = scale_factors
@@ -332,13 +410,21 @@ def replace_latent_frames(cls, latent_image, noise_mask, guiding_latent, latent_
332410
return latent_image, noise_mask
333411

334412
@classmethod
335-
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) -> io.NodeOutput:
413+
def execute(cls, positive, negative, vae, latent, image, frame_idx, strength, iclora_parameters=None) -> io.NodeOutput:
336414
scale_factors = vae.downscale_index_formula
337415
latent_image = latent["samples"]
338416
noise_mask = get_noise_mask(latent)
339417

340418
_, _, latent_length, latent_height, latent_width = latent_image.shape
341419

420+
latent_downscale_factor = cls.get_reference_downscale_factor(iclora_parameters)
421+
if latent_downscale_factor > 1:
422+
if latent_width % latent_downscale_factor != 0 or latent_height % latent_downscale_factor != 0:
423+
raise ValueError(
424+
f"Latent spatial size {latent_width}x{latent_height} must be divisible by "
425+
f"reference_downscale_factor {latent_downscale_factor} from the IC-LoRA parameters."
426+
)
427+
342428
# For mid-video multi-frame guides, prepend+strip a throwaway first frame so the VAE's "first latent = 1 pixel frame" asymmetry lands on the discarded slot
343429
time_scale_factor = scale_factors[0]
344430
num_frames_to_keep = ((image.shape[0] - 1) // time_scale_factor) * time_scale_factor + 1
@@ -351,12 +437,17 @@ def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) ->
351437
if not causal_fix:
352438
image = torch.cat([image[:1], image], dim=0)
353439

354-
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors)
440+
image, t = cls.encode(vae, latent_width, latent_height, image, scale_factors, latent_downscale_factor)
355441

356442
if not causal_fix:
357443
t = t[:, :, 1:, :, :]
358444
image = image[1:]
359445

446+
guide_latent_shape = list(t.shape[2:]) # pre-dilation [F, H, W] for spatial-mask downsampling
447+
guide_mask = None
448+
if latent_downscale_factor > 1:
449+
t, guide_mask = cls.dilate_latent(t, latent_downscale_factor)
450+
360451
frame_idx, latent_idx = cls.get_latent_index(positive, latent_length, len(image), frame_idx, scale_factors)
361452
assert latent_idx + t.shape[2] <= latent_length, "Conditioning frames exceed the length of the latent sequence."
362453

@@ -369,12 +460,13 @@ def execute(cls, positive, negative, vae, latent, image, frame_idx, strength) ->
369460
t,
370461
strength,
371462
scale_factors,
463+
guide_mask=guide_mask,
464+
latent_downscale_factor=latent_downscale_factor,
372465
causal_fix=causal_fix,
373466
)
374467

375468
# Track this guide for per-reference attention control.
376469
pre_filter_count = t.shape[2] * t.shape[3] * t.shape[4]
377-
guide_latent_shape = list(t.shape[2:]) # [F, H, W]
378470
positive, negative = _append_guide_attention_entry(
379471
positive, negative, pre_filter_count, guide_latent_shape, strength=strength,
380472
)
@@ -794,6 +886,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
794886
ModelSamplingLTXV,
795887
LTXVConditioning,
796888
LTXVScheduler,
889+
GetICLoRAParameters,
797890
LTXVAddGuide,
798891
LTXVPreprocess,
799892
LTXVCropGuides,

nodes.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -700,17 +700,19 @@ def load_lora(self, model, clip, lora_name, strength_model, strength_clip):
700700

701701
lora_path = folder_paths.get_full_path_or_raise("loras", lora_name)
702702
lora = None
703+
lora_metadata = None
703704
if self.loaded_lora is not None:
704705
if self.loaded_lora[0] == lora_path:
705706
lora = self.loaded_lora[1]
707+
lora_metadata = self.loaded_lora[2] if len(self.loaded_lora) > 2 else None
706708
else:
707709
self.loaded_lora = None
708710

709711
if lora is None:
710-
lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
711-
self.loaded_lora = (lora_path, lora)
712+
lora, lora_metadata = comfy.utils.load_torch_file(lora_path, safe_load=True, return_metadata=True)
713+
self.loaded_lora = (lora_path, lora, lora_metadata)
712714

713-
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip)
715+
model_lora, clip_lora = comfy.sd.load_lora_for_models(model, clip, lora, strength_model, strength_clip, lora_metadata=lora_metadata)
714716
return (model_lora, clip_lora)
715717

716718
class LoraLoaderModelOnly(LoraLoader):

0 commit comments

Comments
 (0)