@@ -328,7 +328,7 @@ def forward(self, x, y):
328328 kv = torch .cat ((k , v ), dim = - 1 )
329329 split_size = kv .shape [- 1 ] // self .num_heads // 2
330330
331- kv = kv .view (1 , - 1 , self .num_heads , split_size * 2 )
331+ kv = kv .view (b , - 1 , self .num_heads , split_size * 2 )
332332 k , v = torch .split (kv , split_size , dim = - 1 )
333333
334334 q = q .view (b , s1 , self .num_heads , self .head_dim )
@@ -398,7 +398,7 @@ def forward(self, x):
398398 qkv_combined = torch .cat ((query , key , value ), dim = - 1 )
399399 split_size = qkv_combined .shape [- 1 ] // self .num_heads // 3
400400
401- qkv = qkv_combined .view (1 , - 1 , self .num_heads , split_size * 3 )
401+ qkv = qkv_combined .view (B , - 1 , self .num_heads , split_size * 3 )
402402 query , key , value = torch .split (qkv , split_size , dim = - 1 )
403403
404404 query = query .reshape (B , N , self .num_heads , self .head_dim )
@@ -607,9 +607,9 @@ def __init__(
607607 def forward (self , x , t , context , transformer_options = {}, ** kwargs ):
608608
609609 x = x .movedim (- 1 , - 2 )
610- uncond_emb , cond_emb = context .chunk ( 2 , dim = 0 )
611-
612- context = torch .cat ([cond_emb , uncond_emb ], dim = 0 )
610+ if context .shape [ 0 ] >= 2 :
611+ uncond_emb , cond_emb = context . chunk ( 2 , dim = 0 )
612+ context = torch .cat ([cond_emb , uncond_emb ], dim = 0 )
613613 main_condition = context
614614
615615 t = 1.0 - t
@@ -657,5 +657,8 @@ def block_wrap(args):
657657 output = self .final_layer (combined )
658658 output = output .movedim (- 2 , - 1 ) * (- 1.0 )
659659
660- cond_emb , uncond_emb = output .chunk (2 , dim = 0 )
661- return torch .cat ([uncond_emb , cond_emb ])
660+ if output .shape [0 ] >= 2 :
661+ cond_emb , uncond_emb = output .chunk (2 , dim = 0 )
662+ return torch .cat ([uncond_emb , cond_emb ])
663+ else :
664+ return output
0 commit comments