Skip to content

Commit 084e08c

Browse files
authored
Disable sageattention for SAM3 (Comfy-Org#13529)
Causes Nans
1 parent ef8f3cb commit 084e08c

3 files changed

Lines changed: 6 additions & 6 deletions

File tree

comfy/ldm/sam3/detector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def forward(self, q_input, k_input=None, v_input=None, mask=None):
5454
if mask is not None and mask.ndim == 2:
5555
mask = mask[:, None, None, :] # [B, T] -> [B, 1, 1, T] for SDPA broadcast
5656
dtype = q.dtype # manual_cast may produce mixed dtypes
57-
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask)
57+
out = optimized_attention(q, k.to(dtype), v.to(dtype), self.num_heads, mask=mask, low_precision_attention=False)
5858
return self.out_proj(out)
5959

6060

comfy/ldm/sam3/sam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def forward(self, q, k, v):
4040
q = self.q_proj(q)
4141
k = self.k_proj(k)
4242
v = self.v_proj(v)
43-
return self.out_proj(optimized_attention(q, k, v, self.num_heads))
43+
return self.out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
4444

4545

4646
class TwoWayAttentionBlock(nn.Module):
@@ -179,7 +179,7 @@ def forward(self, x, freqs_cis=None):
179179
q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0)
180180
if self.use_rope and freqs_cis is not None:
181181
q, k = apply_rope(q, k, freqs_cis)
182-
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True))
182+
return self.proj(optimized_attention(q, k, v, self.num_heads, skip_reshape=True, low_precision_attention=False))
183183

184184

185185
class Block(nn.Module):

comfy/ldm/sam3/tracker.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def forward(self, q, k=None, v=None, rope=None, num_k_exclude_rope=0):
364364
v = self.v_proj(v)
365365
if rope is not None:
366366
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
367-
out = optimized_attention(q, k, v, self.num_heads)
367+
out = optimized_attention(q, k, v, self.num_heads, low_precision_attention=False)
368368
return self.out_proj(out)
369369

370370

@@ -657,7 +657,7 @@ def forward(self, image, x, memory_image, memory, memory_image_pos=None,
657657
v = self.self_attn_v_proj(normed)
658658
if rope is not None:
659659
q, k = apply_rope_memory(q, k, rope, self.num_heads, 0)
660-
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
660+
x = x + self.self_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
661661

662662
# Decoupled cross-attention: fuse image and memory projections
663663
normed = self.norm2(x)
@@ -668,7 +668,7 @@ def forward(self, image, x, memory_image, memory, memory_image_pos=None,
668668
v = self.cross_attn_v_proj(memory)
669669
if rope is not None:
670670
q, k = apply_rope_memory(q, k, rope, self.num_heads, num_k_exclude_rope)
671-
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads))
671+
x = x + self.cross_attn_out_proj(optimized_attention(q, k, v, self.num_heads, low_precision_attention=False))
672672

673673
# FFN
674674
x = x + self.linear2(F.gelu(self.linear1(self.norm3(x))))

0 commit comments

Comments
 (0)