@@ -118,8 +118,6 @@ def forward(self, x: torch.Tensor, attention_mask: torch.Tensor = None, image_ro
118118 query = apply_rotary_emb (query , image_rotary_emb )
119119 key = apply_rotary_emb (key , image_rotary_emb )
120120
121- query , key = query .to (x .dtype ), key .to (x .dtype )
122-
123121 q_flat = query .reshape (B , S , - 1 )
124122 k_flat = key .reshape (B , S , - 1 )
125123
@@ -161,16 +159,16 @@ def forward(self, x, rotary_pos_emb, temb, attention_mask=None):
161159
162160 residual = x
163161 x_norm = self .adaLN_sa_ln (x )
164- x_norm = ( x_norm . float () * (1 + scale_msa . float ()) + shift_msa . float ()). to ( x . dtype )
162+ x_norm = x_norm * (1 + scale_msa ) + shift_msa
165163
166164 attn_out = self .self_attention (x_norm , attention_mask = attention_mask , image_rotary_emb = rotary_pos_emb )
167- x = residual + ( gate_msa . float () * attn_out . float ()). to ( x . dtype )
165+ x = residual + gate_msa * attn_out
168166
169167 residual = x
170168 x_norm = self .adaLN_mlp_ln (x )
171- x_norm = ( x_norm . float () * (1 + scale_mlp . float ()) + shift_mlp . float ()). to ( x . dtype )
169+ x_norm = x_norm * (1 + scale_mlp ) + shift_mlp
172170
173- return residual + ( gate_mlp . float () * self .mlp (x_norm ). float ()). to ( x . dtype )
171+ return residual + gate_mlp * self .mlp (x_norm )
174172
175173class ErnieImageAdaLNContinuous (nn .Module ):
176174 def __init__ (self , hidden_size : int , eps : float = 1e-6 , operations = None , device = None , dtype = None ):
@@ -183,7 +181,7 @@ def __init__(self, hidden_size: int, eps: float = 1e-6, operations=None, device=
183181 def forward (self , x : torch .Tensor , conditioning : torch .Tensor ) -> torch .Tensor :
184182 scale , shift = self .linear (conditioning ).chunk (2 , dim = - 1 )
185183 x = self .norm (x )
186- x = x * ( 1 + scale .unsqueeze (1 )) + shift .unsqueeze (1 )
184+ x = torch . addcmul ( shift .unsqueeze (1 ), x , 1 + scale .unsqueeze (1 ) )
187185 return x
188186
189187class ErnieImageModel (nn .Module ):
0 commit comments