Skip to content

Commit 0e25a69

Browse files
authored
Reduce video tiny VAE peak VRAM and decode time (CORE-127) (Comfy-Org#13617)
* Update taehv.py * Simplify * Simplify pixel_unshuffle dispatch
1 parent fce0398 commit 0e25a69

1 file changed

Lines changed: 22 additions & 15 deletions

File tree

comfy/taesd/taehv.py

Lines changed: 22 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections import namedtuple, deque
88

99
import comfy.ops
10+
import comfy.model_management
1011
operations=comfy.ops.disable_weight_init
1112

1213
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
@@ -47,11 +48,14 @@ def forward(self, x):
4748
x = self.conv(x)
4849
return x.reshape(-1, C, H, W)
4950

50-
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
51+
def apply_model_with_memblocks(model, x, parallel, show_progress_bar, output_device=None,
52+
patch_size=1, decode=False):
5153

5254
B, T, C, H, W = x.shape
5355
if parallel:
5456
x = x.reshape(B*T, C, H, W)
57+
if not decode and patch_size > 1:
58+
x = F.pixel_unshuffle(x, patch_size)
5559
# parallel over input timesteps, iterate over blocks
5660
for b in tqdm(model, disable=not show_progress_bar):
5761
if isinstance(b, MemBlock):
@@ -62,20 +66,27 @@ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
6266
x = b(x, mem)
6367
else:
6468
x = b(x)
65-
BT, C, H, W = x.shape
66-
T = BT // B
67-
x = x.view(B, T, C, H, W)
69+
if decode and patch_size > 1:
70+
x = F.pixel_shuffle(x, patch_size)
71+
x = x.view(B, x.shape[0] // B, *x.shape[1:])
72+
x = x.to(output_device)
6873
else:
6974
out = []
70-
work_queue = deque([TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(B, T * C, H, W).chunk(T, dim=1))])
75+
# Chunk along the time dim directly (chunks are [B,1,C,H,W] views, squeeze to [B,C,H,W] views).
76+
# Avoids forcing a contiguous copy when x is non-contiguous (e.g. after movedim in encode/decode).
77+
work_queue = deque([TWorkItem(xt.squeeze(1), 0) for xt in x.chunk(T, dim=1)])
7178
progress_bar = tqdm(range(T), disable=not show_progress_bar)
7279
mem = [None] * len(model)
7380
while work_queue:
7481
xt, i = work_queue.popleft()
7582
if i == 0:
7683
progress_bar.update(1)
84+
if not decode and patch_size > 1:
85+
xt = F.pixel_unshuffle(xt, patch_size)
7786
if i == len(model):
78-
out.append(xt)
87+
if decode and patch_size > 1:
88+
xt = F.pixel_shuffle(xt, patch_size)
89+
out.append(xt.to(output_device))
7990
del xt
8091
else:
8192
b = model[i]
@@ -165,24 +176,20 @@ def show_progress_bar(self, value):
165176

166177
def encode(self, x, **kwargs):
167178
x = x.movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
168-
if self.patch_size > 1:
169-
B, T, C, H, W = x.shape
170-
x = x.reshape(B * T, C, H, W)
171-
x = F.pixel_unshuffle(x, self.patch_size)
172-
x = x.reshape(B, T, C * self.patch_size ** 2, H // self.patch_size, W // self.patch_size)
173179
if x.shape[1] % self.t_downscale != 0:
174180
# pad at end to multiple of t_downscale
175181
n_pad = self.t_downscale - x.shape[1] % self.t_downscale
176182
padding = x[:, -1:].repeat_interleave(n_pad, dim=1)
177183
x = torch.cat([x, padding], 1)
178-
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar).movedim(2, 1)
184+
x = apply_model_with_memblocks(self.encoder, x, self.parallel, self.show_progress_bar,
185+
patch_size=self.patch_size).movedim(2, 1)
179186
return self.process_out(x)
180187

181188
def decode(self, x, **kwargs):
182189
x = x.unsqueeze(0) if x.ndim == 4 else x # [T, C, H, W] -> [1, T, C, H, W]
183190
x = x.movedim(1, 2) if x.shape[1] != self.latent_channels else x # [B, T, C, H, W] or [B, C, T, H, W]
184191
x = self.process_in(x).movedim(2, 1) # [B, C, T, H, W] -> [B, T, C, H, W]
185-
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar)
186-
if self.patch_size > 1:
187-
x = F.pixel_shuffle(x, self.patch_size)
192+
x = apply_model_with_memblocks(self.decoder, x, self.parallel, self.show_progress_bar,
193+
output_device=comfy.model_management.intermediate_device(),
194+
patch_size=self.patch_size, decode=True)
188195
return x[:, self.frames_to_trim:].movedim(2, 1)

0 commit comments

Comments
 (0)