Skip to content

Commit 44f1246

Browse files
Support flux 2 klein kv cache model: Use the FluxKVCache node. (Comfy-Org#12905)
1 parent 8f9ea49 commit 44f1246

2 files changed

Lines changed: 129 additions & 11 deletions

File tree

comfy/ldm/flux/model.py

Lines changed: 65 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,22 @@ class FluxParams:
4444
txt_norm: bool = False
4545

4646

47+
def invert_slices(slices, length):
    """Return the parts of ``[0, length)`` that are not covered by ``slices``.

    Args:
        slices: Iterable of ``(start, end)`` pairs describing half-open
            ranges. The pairs may be given in any order and may overlap.
        length: Exclusive upper bound of the range being partitioned.

    Returns:
        A list of ``(start, end)`` tuples in ascending order, one per gap
        between the given slices. Every returned range lies inside
        ``[0, length)``.
    """
    gaps = []
    cursor = 0  # first index not yet covered by a slice or emitted as a gap

    for start, end in sorted(slices):
        # Clamp to ``length`` so a slice that starts past the end cannot
        # produce a gap reaching outside the requested range.
        gap_end = min(start, length)
        if cursor < gap_end:
            gaps.append((cursor, gap_end))
        # max() keeps the cursor monotonic when slices overlap or nest.
        cursor = max(cursor, end)

    if cursor < length:
        gaps.append((cursor, length))

    return gaps
61+
62+
4763
class Flux(nn.Module):
4864
"""
4965
Transformer model for flow matching on sequences.
@@ -138,6 +154,7 @@ def forward_orig(
138154
y: Tensor,
139155
guidance: Tensor = None,
140156
control = None,
157+
timestep_zero_index=None,
141158
transformer_options={},
142159
attn_mask: Tensor = None,
143160
) -> Tensor:
@@ -164,10 +181,6 @@ def forward_orig(
164181
txt = self.txt_norm(txt)
165182
txt = self.txt_in(txt)
166183

167-
vec_orig = vec
168-
if self.params.global_modulation:
169-
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
170-
171184
if "post_input" in patches:
172185
for p in patches["post_input"]:
173186
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@@ -182,6 +195,24 @@ def forward_orig(
182195
else:
183196
pe = None
184197

198+
vec_orig = vec
199+
txt_vec = vec
200+
extra_kwargs = {}
201+
if timestep_zero_index is not None:
202+
modulation_dims = []
203+
batch = vec.shape[0] // 2
204+
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
205+
invert = invert_slices(timestep_zero_index, img.shape[1])
206+
for s in invert:
207+
modulation_dims.append((s[0], s[1], 0))
208+
for s in timestep_zero_index:
209+
modulation_dims.append((s[0], s[1], 1))
210+
extra_kwargs["modulation_dims_img"] = modulation_dims
211+
txt_vec = vec[:batch]
212+
213+
if self.params.global_modulation:
214+
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
215+
185216
blocks_replace = patches_replace.get("dit", {})
186217
transformer_options["total_blocks"] = len(self.double_blocks)
187218
transformer_options["block_type"] = "double"
@@ -195,7 +226,8 @@ def block_wrap(args):
195226
vec=args["vec"],
196227
pe=args["pe"],
197228
attn_mask=args.get("attn_mask"),
198-
transformer_options=args.get("transformer_options"))
229+
transformer_options=args.get("transformer_options"),
230+
**extra_kwargs)
199231
return out
200232

201233
out = blocks_replace[("double_block", i)]({"img": img,
@@ -213,7 +245,8 @@ def block_wrap(args):
213245
vec=vec,
214246
pe=pe,
215247
attn_mask=attn_mask,
216-
transformer_options=transformer_options)
248+
transformer_options=transformer_options,
249+
**extra_kwargs)
217250

218251
if control is not None: # Controlnet
219252
control_i = control.get("input")
@@ -230,6 +263,12 @@ def block_wrap(args):
230263
if self.params.global_modulation:
231264
vec, _ = self.single_stream_modulation(vec_orig)
232265

266+
extra_kwargs = {}
267+
if timestep_zero_index is not None:
268+
lambda a: 0 if a == 0 else a + txt.shape[1]
269+
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
270+
extra_kwargs["modulation_dims"] = modulation_dims_combined
271+
233272
transformer_options["total_blocks"] = len(self.single_blocks)
234273
transformer_options["block_type"] = "single"
235274
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@@ -242,7 +281,8 @@ def block_wrap(args):
242281
vec=args["vec"],
243282
pe=args["pe"],
244283
attn_mask=args.get("attn_mask"),
245-
transformer_options=args.get("transformer_options"))
284+
transformer_options=args.get("transformer_options"),
285+
**extra_kwargs)
246286
return out
247287

248288
out = blocks_replace[("single_block", i)]({"img": img,
@@ -253,7 +293,7 @@ def block_wrap(args):
253293
{"original_block": block_wrap})
254294
img = out["img"]
255295
else:
256-
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
296+
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
257297

258298
if control is not None: # Controlnet
259299
control_o = control.get("output")
@@ -264,7 +304,11 @@ def block_wrap(args):
264304

265305
img = img[:, txt.shape[1] :, ...]
266306

267-
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
307+
extra_kwargs = {}
308+
if timestep_zero_index is not None:
309+
extra_kwargs["modulation_dims"] = modulation_dims
310+
311+
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
268312
return img
269313

270314
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@@ -312,13 +356,16 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
312356
w_len = ((w_orig + (patch_size // 2)) // patch_size)
313357
img, img_ids = self.process_img(x, transformer_options=transformer_options)
314358
img_tokens = img.shape[1]
359+
timestep_zero_index = None
315360
if ref_latents is not None:
361+
ref_num_tokens = []
316362
h = 0
317363
w = 0
318364
index = 0
319365
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
366+
timestep_zero = ref_latents_method == "index_timestep_zero"
320367
for ref in ref_latents:
321-
if ref_latents_method == "index":
368+
if ref_latents_method in ("index", "index_timestep_zero"):
322369
index += self.params.ref_index_scale
323370
h_offset = 0
324371
w_offset = 0
@@ -342,13 +389,20 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
342389
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
343390
img = torch.cat([img, kontext], dim=1)
344391
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
392+
ref_num_tokens.append(kontext.shape[1])
393+
if timestep_zero:
394+
if index > 0:
395+
timestep = torch.cat([timestep, timestep * 0], dim=0)
396+
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
397+
transformer_options = transformer_options.copy()
398+
transformer_options["reference_image_num_tokens"] = ref_num_tokens
345399

346400
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
347401

348402
if len(self.params.txt_ids_dims) > 0:
349403
for i in self.params.txt_ids_dims:
350404
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
351405

352-
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
406+
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
353407
out = out[:, :img_tokens]
354408
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

comfy_extras/nodes_flux.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import torch
77
import math
88
import nodes
9+
import comfy.ldm.flux.math
910

1011
class CLIPTextEncodeFlux(io.ComfyNode):
1112
@classmethod
@@ -231,6 +232,68 @@ def execute(cls, steps, width, height) -> io.NodeOutput:
231232
sigmas = get_schedule(steps, round(seq_len))
232233
return io.NodeOutput(sigmas)
233234

235+
class KV_Attn_Input:
    """Attention-input patch that caches the reference-image keys and values.

    The first time a given block is seen, the last ``ref_toks`` entries of
    ``k`` and ``v`` along dim 2 are stored. On later calls for the same block
    the stored slices are concatenated back onto the incoming ``k`` and ``v``
    along dim 2.
    NOTE(review): dim 2 is presumably the token axis - confirm against the
    attention call site.
    """

    def __init__(self):
        # Maps "<block_type>_<block_index>" -> (k_slice, v_slice).
        self.cache = {}
        # True when the most recent call stored a cache entry, False when it
        # reused one. Initialised here so it can be read before any call.
        self.set_cache = False

    def __call__(self, q, k, v, extra_options, **kwargs):
        """Store or re-attach the cached reference keys/values for one block.

        Args:
            q, k, v: Attention inputs; ``k`` and ``v`` are sliced and
                concatenated along dim 2.
            extra_options: Mapping that must contain ``"block_type"`` and
                ``"block_index"`` whenever reference tokens are present, and
                may contain ``"reference_image_num_tokens"`` (a list of
                per-reference token counts).
            **kwargs: Accepted and ignored.

        Returns:
            ``{}`` when there are no reference tokens, otherwise a dict with
            the ``"q"``, ``"k"`` and ``"v"`` to use.
        """
        ref_toks = sum(extra_options.get("reference_image_num_tokens", []))
        # A total of zero must be treated like "no references": ``-0:`` would
        # otherwise select the *whole* tensor and cache it.
        if ref_toks <= 0:
            return {}

        cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
        cached = self.cache.get(cache_key)
        if cached is not None:
            cached_k, cached_v = cached
            self.set_cache = False
            return {
                "q": q,
                "k": torch.cat((k, cached_k), dim=2),
                "v": torch.cat((v, cached_v), dim=2),
            }

        self.cache[cache_key] = (k[:, :, -ref_toks:], v[:, :, -ref_toks:])
        self.set_cache = True
        return {"q": q, "k": k, "v": v}

    def cleanup(self):
        """Drop every cached key/value pair and reset the cache-state flag."""
        self.cache = {}
        self.set_cache = False
257+
258+
259+
class FluxKVCache(io.ComfyNode):
    """Node that clones a model and installs the ``KV_Attn_Input`` patches on it."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Describe the node: one model input and one patched-model output."""
        model_input = io.Model.Input("model", tooltip="The model to use KV Cache on.")
        model_output = io.Model.Output(tooltip="The patched model with KV Cache enabled.")
        return io.Schema(
            node_id="FluxKVCache",
            display_name="Flux KV Cache",
            description="Enables KV Cache optimization for reference images on Flux family models.",
            category="",
            is_experimental=True,
            inputs=[model_input],
            outputs=[model_output],
        )

    @classmethod
    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
        """Return a clone of ``model`` with the KV-cache patches applied.

        The clone gets an attn1 patch (a fresh ``KV_Attn_Input``), a
        post-input patch that trims trailing image tokens once the cache is
        populated, and an object patch setting ``default_ref_method`` to
        ``"index_timestep_zero"``.
        """
        patched = model.clone()
        kv_patch = KV_Attn_Input()

        def drop_cached_reference_tokens(inputs):
            # While nothing is cached the input is passed through untouched.
            if not kv_patch.cache:
                return inputs
            ref_counts = inputs["transformer_options"].get("reference_image_num_tokens", [])
            total_ref_tokens = sum(ref_counts)
            if total_ref_tokens > 0:
                # Strip the trailing reference tokens from the image sequence.
                inputs["img"] = inputs["img"][:, :-total_ref_tokens]
            return inputs

        patched.set_model_attn1_patch(kv_patch)
        patched.set_model_post_input_patch(drop_cached_reference_tokens)

        # The attribute lives under ``params`` when the diffusion model has one.
        if hasattr(model.model.diffusion_model, "params"):
            ref_method_key = "diffusion_model.params.default_ref_method"
        else:
            ref_method_key = "diffusion_model.default_ref_method"
        patched.add_object_patch(ref_method_key, "index_timestep_zero")

        return io.NodeOutput(patched)
234297

235298
class FluxExtension(ComfyExtension):
236299
@override
@@ -243,6 +306,7 @@ async def get_node_list(self) -> list[type[io.ComfyNode]]:
243306
FluxKontextMultiReferenceLatentMethod,
244307
EmptyFlux2LatentImage,
245308
Flux2Scheduler,
309+
FluxKVCache,
246310
]
247311

248312

0 commit comments

Comments
 (0)