Skip to content

Commit 3d816db

Browse files
Some optimizations to make Ernie inference a bit faster. (Comfy-Org#13472)
1 parent b9dedea commit 3d816db

1 file changed

Lines changed: 5 additions & 7 deletions

File tree

comfy/ldm/ernie/model.py

Lines changed: 5 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -118,8 +118,6 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_ro
118118
query = apply_rotary_emb(query, image_rotary_emb)
119119
key = apply_rotary_emb(key, image_rotary_emb)
120120

121-
query, key = query.to(x.dtype), key.to(x.dtype)
122-
123121
q_flat = query.reshape(B, S, -1)
124122
k_flat = key.reshape(B, S, -1)
125123

@@ -161,16 +159,16 @@ def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
161159

162160
residual = x
163161
x_norm = self.adaLN_sa_ln(x)
164-
x_norm = (x_norm.float() * (1 + scale_msa.float()) + shift_msa.float()).to(x.dtype)
162+
x_norm = x_norm * (1 + scale_msa) + shift_msa
165163

166164
attn_out = self.self_attention(x_norm, attention_mask=attention_mask, image_rotary_emb=rotary_pos_emb)
167-
x = residual + (gate_msa.float() * attn_out.float()).to(x.dtype)
165+
x = residual + gate_msa * attn_out
168166

169167
residual = x
170168
x_norm = self.adaLN_mlp_ln(x)
171-
x_norm = (x_norm.float() * (1 + scale_mlp.float()) + shift_mlp.float()).to(x.dtype)
169+
x_norm = x_norm * (1 + scale_mlp) + shift_mlp
172170

173-
return residual + (gate_mlp.float() * self.mlp(x_norm).float()).to(x.dtype)
171+
return residual + gate_mlp * self.mlp(x_norm)
174172

175173
class ErnieImageAdaLNContinuous(nn.Module):
176174
def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=None, dtype=None):
@@ -183,7 +181,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=
183181
def forward(self, x: torch.Tensor, conditioning: torch.Tensor) -> torch.Tensor:
184182
scale, shift = self.linear(conditioning).chunk(2, dim=-1)
185183
x = self.norm(x)
186-
x = x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
184+
x = torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1))
187185
return x
188186

189187
class ErnieImageModel(nn.Module):

0 commit comments

Comments (0)