Skip to content

Commit c4fb027

Browse files
Add a way for nodes to add pre attn patches to flux model. (Comfy-Org#12861)
1 parent 740d998 commit c4fb027

3 files changed

Lines changed: 17 additions & 2 deletions

File tree

comfy/ldm/flux/layers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,19 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
223223
del txt_k, img_k
224224
v = torch.cat((txt_v, img_v), dim=2)
225225
del txt_v, img_v
226+
227+
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
228+
if "attn1_patch" in transformer_patches:
229+
patch = transformer_patches["attn1_patch"]
230+
for p in patch:
231+
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
232+
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
233+
226234
# run actual attention
227235
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
228236
del q, k, v
229237

230238
if "attn1_output_patch" in transformer_patches:
231-
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
232239
patch = transformer_patches["attn1_output_patch"]
233240
for p in patch:
234241
attn = p(attn, extra_options)
@@ -321,6 +328,12 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
321328
del qkv
322329
q, k = self.norm(q, k, v)
323330

331+
if "attn1_patch" in transformer_patches:
332+
patch = transformer_patches["attn1_patch"]
333+
for p in patch:
334+
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
335+
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
336+
324337
# compute attention
325338
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
326339
del q, k, v

comfy/ldm/flux/math.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
3131

3232
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
3333
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
34+
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
35+
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
3436

3537
x_out = freqs_cis[..., 0] * x_[..., 0]
3638
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

comfy/ldm/flux/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ def forward_orig(
170170

171171
if "post_input" in patches:
172172
for p in patches["post_input"]:
173-
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
173+
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
174174
img = out["img"]
175175
txt = out["txt"]
176176
img_ids = out["img_ids"]

0 commit comments

Comments
 (0)