@@ -364,7 +364,7 @@ def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0):
364364 v = self .v_proj (v )
365365 if rope is not None :
366366 q , k = apply_rope_memory (q , k , rope , self .num_heads , num_k_exclude_rope )
367- out = optimized_attention (q , k , v , self .num_heads )
367+ out = optimized_attention (q , k , v , self .num_heads , low_precision_attention = False )
368368 return self .out_proj (out )
369369
370370
@@ -657,7 +657,7 @@ def forward(self, image, x, memory_image, memory, memory_image_pos=None,
657657 v = self .self_attn_v_proj (normed )
658658 if rope is not None :
659659 q , k = apply_rope_memory (q , k , rope , self .num_heads , 0 )
660- x = x + self .self_attn_out_proj (optimized_attention (q , k , v , self .num_heads ))
660+ x = x + self .self_attn_out_proj (optimized_attention (q , k , v , self .num_heads , low_precision_attention = False ))
661661
662662 # Decoupled cross-attention: fuse image and memory projections
663663 normed = self .norm2 (x )
@@ -668,7 +668,7 @@ def forward(self, image, x, memory_image, memory, memory_image_pos=None,
668668 v = self .cross_attn_v_proj (memory )
669669 if rope is not None :
670670 q , k = apply_rope_memory (q , k , rope , self .num_heads , num_k_exclude_rope )
671- x = x + self .cross_attn_out_proj (optimized_attention (q , k , v , self .num_heads ))
671+ x = x + self .cross_attn_out_proj (optimized_attention (q , k , v , self .num_heads , low_precision_attention = False ))
672672
673673 # FFN
674674 x = x + self .linear2 (F .gelu (self .linear1 (self .norm3 (x ))))
0 commit comments