@@ -44,6 +44,22 @@ class FluxParams:
4444 txt_norm : bool = False
4545
4646
47+ def invert_slices (slices , length ):
48+ sorted_slices = sorted (slices )
49+ result = []
50+ current = 0
51+
52+ for start , end in sorted_slices :
53+ if current < start :
54+ result .append ((current , start ))
55+ current = max (current , end )
56+
57+ if current < length :
58+ result .append ((current , length ))
59+
60+ return result
61+
62+
4763class Flux (nn .Module ):
4864 """
4965 Transformer model for flow matching on sequences.
@@ -138,6 +154,7 @@ def forward_orig(
138154 y : Tensor ,
139155 guidance : Tensor = None ,
140156 control = None ,
157+ timestep_zero_index = None ,
141158 transformer_options = {},
142159 attn_mask : Tensor = None ,
143160 ) -> Tensor :
@@ -164,10 +181,6 @@ def forward_orig(
164181 txt = self .txt_norm (txt )
165182 txt = self .txt_in (txt )
166183
167- vec_orig = vec
168- if self .params .global_modulation :
169- vec = (self .double_stream_modulation_img (vec_orig ), self .double_stream_modulation_txt (vec_orig ))
170-
171184 if "post_input" in patches :
172185 for p in patches ["post_input" ]:
173186 out = p ({"img" : img , "txt" : txt , "img_ids" : img_ids , "txt_ids" : txt_ids , "transformer_options" : transformer_options })
@@ -182,6 +195,24 @@ def forward_orig(
182195 else :
183196 pe = None
184197
198+ vec_orig = vec
199+ txt_vec = vec
200+ extra_kwargs = {}
201+ if timestep_zero_index is not None :
202+ modulation_dims = []
203+ batch = vec .shape [0 ] // 2
204+ vec_orig = vec_orig .reshape (2 , batch , vec .shape [1 ]).movedim (0 , 1 )
205+ invert = invert_slices (timestep_zero_index , img .shape [1 ])
206+ for s in invert :
207+ modulation_dims .append ((s [0 ], s [1 ], 0 ))
208+ for s in timestep_zero_index :
209+ modulation_dims .append ((s [0 ], s [1 ], 1 ))
210+ extra_kwargs ["modulation_dims_img" ] = modulation_dims
211+ txt_vec = vec [:batch ]
212+
213+ if self .params .global_modulation :
214+ vec = (self .double_stream_modulation_img (vec_orig ), self .double_stream_modulation_txt (txt_vec ))
215+
185216 blocks_replace = patches_replace .get ("dit" , {})
186217 transformer_options ["total_blocks" ] = len (self .double_blocks )
187218 transformer_options ["block_type" ] = "double"
@@ -195,7 +226,8 @@ def block_wrap(args):
195226 vec = args ["vec" ],
196227 pe = args ["pe" ],
197228 attn_mask = args .get ("attn_mask" ),
198- transformer_options = args .get ("transformer_options" ))
229+ transformer_options = args .get ("transformer_options" ),
230+ ** extra_kwargs )
199231 return out
200232
201233 out = blocks_replace [("double_block" , i )]({"img" : img ,
@@ -213,7 +245,8 @@ def block_wrap(args):
213245 vec = vec ,
214246 pe = pe ,
215247 attn_mask = attn_mask ,
216- transformer_options = transformer_options )
248+ transformer_options = transformer_options ,
249+ ** extra_kwargs )
217250
218251 if control is not None : # Controlnet
219252 control_i = control .get ("input" )
@@ -230,6 +263,12 @@ def block_wrap(args):
230263 if self .params .global_modulation :
231264 vec , _ = self .single_stream_modulation (vec_orig )
232265
266+ extra_kwargs = {}
267+ if timestep_zero_index is not None :
268+ lambda a : 0 if a == 0 else a + txt .shape [1 ]
269+ modulation_dims_combined = list (map (lambda x : (0 if x [0 ] == 0 else x [0 ] + txt .shape [1 ], x [1 ] + txt .shape [1 ], x [2 ]), modulation_dims ))
270+ extra_kwargs ["modulation_dims" ] = modulation_dims_combined
271+
233272 transformer_options ["total_blocks" ] = len (self .single_blocks )
234273 transformer_options ["block_type" ] = "single"
235274 transformer_options ["img_slice" ] = [txt .shape [1 ], img .shape [1 ]]
@@ -242,7 +281,8 @@ def block_wrap(args):
242281 vec = args ["vec" ],
243282 pe = args ["pe" ],
244283 attn_mask = args .get ("attn_mask" ),
245- transformer_options = args .get ("transformer_options" ))
284+ transformer_options = args .get ("transformer_options" ),
285+ ** extra_kwargs )
246286 return out
247287
248288 out = blocks_replace [("single_block" , i )]({"img" : img ,
@@ -253,7 +293,7 @@ def block_wrap(args):
253293 {"original_block" : block_wrap })
254294 img = out ["img" ]
255295 else :
256- img = block (img , vec = vec , pe = pe , attn_mask = attn_mask , transformer_options = transformer_options )
296+ img = block (img , vec = vec , pe = pe , attn_mask = attn_mask , transformer_options = transformer_options , ** extra_kwargs )
257297
258298 if control is not None : # Controlnet
259299 control_o = control .get ("output" )
@@ -264,7 +304,11 @@ def block_wrap(args):
264304
265305 img = img [:, txt .shape [1 ] :, ...]
266306
267- img = self .final_layer (img , vec_orig ) # (N, T, patch_size ** 2 * out_channels)
307+ extra_kwargs = {}
308+ if timestep_zero_index is not None :
309+ extra_kwargs ["modulation_dims" ] = modulation_dims
310+
311+ img = self .final_layer (img , vec_orig , ** extra_kwargs ) # (N, T, patch_size ** 2 * out_channels)
268312 return img
269313
270314 def process_img (self , x , index = 0 , h_offset = 0 , w_offset = 0 , transformer_options = {}):
@@ -312,13 +356,16 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
312356 w_len = ((w_orig + (patch_size // 2 )) // patch_size )
313357 img , img_ids = self .process_img (x , transformer_options = transformer_options )
314358 img_tokens = img .shape [1 ]
359+ timestep_zero_index = None
315360 if ref_latents is not None :
361+ ref_num_tokens = []
316362 h = 0
317363 w = 0
318364 index = 0
319365 ref_latents_method = kwargs .get ("ref_latents_method" , self .params .default_ref_method )
366+ timestep_zero = ref_latents_method == "index_timestep_zero"
320367 for ref in ref_latents :
321- if ref_latents_method == "index" :
368+ if ref_latents_method in ( "index" , "index_timestep_zero" ) :
322369 index += self .params .ref_index_scale
323370 h_offset = 0
324371 w_offset = 0
@@ -342,13 +389,20 @@ def _forward(self, x, timestep, context, y=None, guidance=None, ref_latents=None
342389 kontext , kontext_ids = self .process_img (ref , index = index , h_offset = h_offset , w_offset = w_offset )
343390 img = torch .cat ([img , kontext ], dim = 1 )
344391 img_ids = torch .cat ([img_ids , kontext_ids ], dim = 1 )
392+ ref_num_tokens .append (kontext .shape [1 ])
393+ if timestep_zero :
394+ if index > 0 :
395+ timestep = torch .cat ([timestep , timestep * 0 ], dim = 0 )
396+ timestep_zero_index = [[img_tokens , img_ids .shape [1 ]]]
397+ transformer_options = transformer_options .copy ()
398+ transformer_options ["reference_image_num_tokens" ] = ref_num_tokens
345399
346400 txt_ids = torch .zeros ((bs , context .shape [1 ], len (self .params .axes_dim )), device = x .device , dtype = torch .float32 )
347401
348402 if len (self .params .txt_ids_dims ) > 0 :
349403 for i in self .params .txt_ids_dims :
350404 txt_ids [:, :, i ] = torch .linspace (0 , context .shape [1 ] - 1 , steps = context .shape [1 ], device = x .device , dtype = torch .float32 )
351405
352- out = self .forward_orig (img , img_ids , context , txt_ids , timestep , y , guidance , control , transformer_options , attn_mask = kwargs .get ("attention_mask" , None ))
406+ out = self .forward_orig (img , img_ids , context , txt_ids , timestep , y , guidance , control , timestep_zero_index = timestep_zero_index , transformer_options = transformer_options , attn_mask = kwargs .get ("attention_mask" , None ))
353407 out = out [:, :img_tokens ]
354408 return rearrange (out , "b (h w) (c ph pw) -> b c (h ph) (w pw)" , h = h_len , w = w_len , ph = self .patch_size , pw = self .patch_size )[:,:,:h_orig ,:w_orig ]
0 commit comments