Skip to content

Commit 78b5dec

Browse files
authored
fix: Hunyuan3D 2.1 batch size crashes in attention and forward pass (Comfy-Org#13699)
1 parent 72e3f60 commit 78b5dec

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

comfy/ldm/hunyuan3dv2_1/hunyuandit.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def forward(self, x, y):
328328
kv = torch.cat((k, v), dim=-1)
329329
split_size = kv.shape[-1] // self.num_heads // 2
330330

331-
kv = kv.view(1, -1, self.num_heads, split_size * 2)
331+
kv = kv.view(b, -1, self.num_heads, split_size * 2)
332332
k, v = torch.split(kv, split_size, dim=-1)
333333

334334
q = q.view(b, s1, self.num_heads, self.head_dim)
@@ -398,7 +398,7 @@ def forward(self, x):
398398
qkv_combined = torch.cat((query, key, value), dim=-1)
399399
split_size = qkv_combined.shape[-1] // self.num_heads // 3
400400

401-
qkv = qkv_combined.view(1, -1, self.num_heads, split_size * 3)
401+
qkv = qkv_combined.view(B, -1, self.num_heads, split_size * 3)
402402
query, key, value = torch.split(qkv, split_size, dim=-1)
403403

404404
query = query.reshape(B, N, self.num_heads, self.head_dim)
@@ -607,9 +607,9 @@ def __init__(
607607
def forward(self, x, t, context, transformer_options = {}, **kwargs):
608608

609609
x = x.movedim(-1, -2)
610-
uncond_emb, cond_emb = context.chunk(2, dim = 0)
611-
612-
context = torch.cat([cond_emb, uncond_emb], dim = 0)
610+
if context.shape[0] >= 2:
611+
uncond_emb, cond_emb = context.chunk(2, dim = 0)
612+
context = torch.cat([cond_emb, uncond_emb], dim = 0)
613613
main_condition = context
614614

615615
t = 1.0 - t
@@ -657,5 +657,8 @@ def block_wrap(args):
657657
output = self.final_layer(combined)
658658
output = output.movedim(-2, -1) * (-1.0)
659659

660-
cond_emb, uncond_emb = output.chunk(2, dim = 0)
661-
return torch.cat([uncond_emb, cond_emb])
660+
if output.shape[0] >= 2:
661+
cond_emb, uncond_emb = output.chunk(2, dim = 0)
662+
return torch.cat([uncond_emb, cond_emb])
663+
else:
664+
return output

0 commit comments

Comments
 (0)