77from collections import namedtuple , deque
88
99import comfy .ops
10+ import comfy .model_management
1011operations = comfy .ops .disable_weight_init
1112
1213DecoderResult = namedtuple ("DecoderResult" , ("frame" , "memory" ))
@@ -47,11 +48,14 @@ def forward(self, x):
4748 x = self .conv (x )
4849 return x .reshape (- 1 , C , H , W )
4950
50- def apply_model_with_memblocks (model , x , parallel , show_progress_bar ):
51+ def apply_model_with_memblocks (model , x , parallel , show_progress_bar , output_device = None ,
52+ patch_size = 1 , decode = False ):
5153
5254 B , T , C , H , W = x .shape
5355 if parallel :
5456 x = x .reshape (B * T , C , H , W )
57+ if not decode and patch_size > 1 :
58+ x = F .pixel_unshuffle (x , patch_size )
5559 # parallel over input timesteps, iterate over blocks
5660 for b in tqdm (model , disable = not show_progress_bar ):
5761 if isinstance (b , MemBlock ):
@@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
6266 x = b (x , mem )
6367 else :
6468 x = b (x )
65- BT , C , H , W = x .shape
66- T = BT // B
67- x = x .view (B , T , C , H , W )
69+ if decode and patch_size > 1 :
70+ x = F .pixel_shuffle (x , patch_size )
71+ x = x .view (B , x .shape [0 ] // B , * x .shape [1 :])
72+ x = x .to (output_device )
6873 else :
6974 out = []
70- work_queue = deque ([TWorkItem (xt , 0 ) for t , xt in enumerate (x .reshape (B , T * C , H , W ).chunk (T , dim = 1 ))])
75+ # Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
76+ # Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
77+ work_queue = deque ([TWorkItem (xt .squeeze (1 ), 0 ) for xt in x .chunk (T , dim = 1 )])
7178 progress_bar = tqdm (range (T ), disable = not show_progress_bar )
7279 mem = [None ] * len (model )
7380 while work_queue :
7481 xt , i = work_queue .popleft ()
7582 if i == 0 :
7683 progress_bar .update (1 )
84+ if not decode and patch_size > 1 :
85+ xt = F .pixel_unshuffle (xt , patch_size )
7786 if i == len (model ):
78- out .append (xt )
87+ if decode and patch_size > 1 :
88+ xt = F .pixel_shuffle (xt , patch_size )
89+ out .append (xt .to (output_device ))
7990 del xt
8091 else :
8192 b = model [i ]
@@ -165,24 +176,20 @@ def show_progress_bar(self, value):
165176
166177 def encode (self , x , ** kwargs ):
167178 x = x .movedim (2 , 1 ) # [B, C, T, H, W] -> [B, T, C, H, W]
168- if self .patch_size > 1 :
169- B , T , C , H , W = x .shape
170- x = x .reshape (B * T , C , H , W )
171- x = F .pixel_unshuffle (x , self .patch_size )
172- x = x .reshape (B , T , C * self .patch_size ** 2 , H // self .patch_size , W // self .patch_size )
173179 if x .shape [1 ] % self .t_downscale != 0 :
174180 # pad at end to multiple of t_downscale
175181 n_pad = self .t_downscale - x .shape [1 ] % self .t_downscale
176182 padding = x [:, - 1 :].repeat_interleave (n_pad , dim = 1 )
177183 x = torch .cat ([x , padding ], 1 )
178- x = apply_model_with_memblocks (self .encoder , x , self .parallel , self .show_progress_bar ).movedim (2 , 1 )
184+ x = apply_model_with_memblocks (self .encoder , x , self .parallel , self .show_progress_bar ,
185+ patch_size = self .patch_size ).movedim (2 , 1 )
179186 return self .process_out (x )
180187
181188 def decode (self , x , ** kwargs ):
182189 x = x .unsqueeze (0 ) if x .ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
183190 x = x .movedim (1 , 2 ) if x .shape [1 ] != self .latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
184191 x = self .process_in (x ).movedim (2 , 1 ) # [B, C, T, H, W] -> [B, T, C, H, W]
185- x = apply_model_with_memblocks (self .decoder , x , self .parallel , self .show_progress_bar )
186- if self . patch_size > 1 :
187- x = F . pixel_shuffle ( x , self .patch_size )
192+ x = apply_model_with_memblocks (self .decoder , x , self .parallel , self .show_progress_bar ,
193+ output_device = comfy . model_management . intermediate_device (),
194+ patch_size = self .patch_size , decode = True )
188195 return x [:, self .frames_to_trim :].movedim (2 , 1 )
0 commit comments