@@ -223,12 +223,19 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
223223 del txt_k , img_k
224224 v = torch .cat ((txt_v , img_v ), dim = 2 )
225225 del txt_v , img_v
226+
227+ extra_options ["img_slice" ] = [txt .shape [1 ], q .shape [2 ]]
228+ if "attn1_patch" in transformer_patches :
229+ patch = transformer_patches ["attn1_patch" ]
230+ for p in patch :
231+ out = p (q , k , v , pe = pe , attn_mask = attn_mask , extra_options = extra_options )
232+ q , k , v , pe , attn_mask = out .get ("q" , q ), out .get ("k" , k ), out .get ("v" , v ), out .get ("pe" , pe ), out .get ("attn_mask" , attn_mask )
233+
226234 # run actual attention
227235 attn = attention (q , k , v , pe = pe , mask = attn_mask , transformer_options = transformer_options )
228236 del q , k , v
229237
230238 if "attn1_output_patch" in transformer_patches :
231- extra_options ["img_slice" ] = [txt .shape [1 ], attn .shape [1 ]]
232239 patch = transformer_patches ["attn1_output_patch" ]
233240 for p in patch :
234241 attn = p (attn , extra_options )
@@ -321,6 +328,12 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation
321328 del qkv
322329 q , k = self .norm (q , k , v )
323330
331+ if "attn1_patch" in transformer_patches :
332+ patch = transformer_patches ["attn1_patch" ]
333+ for p in patch :
334+ out = p (q , k , v , pe = pe , attn_mask = attn_mask , extra_options = extra_options )
335+ q , k , v , pe , attn_mask = out .get ("q" , q ), out .get ("k" , k ), out .get ("v" , v ), out .get ("pe" , pe ), out .get ("attn_mask" , attn_mask )
336+
324337 # compute attention
325338 attn = attention (q , k , v , pe = pe , mask = attn_mask , transformer_options = transformer_options )
326339 del q , k , v
0 commit comments