diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py
index 90ffcac80dc5..d6a4fd019b5b 100644
--- a/src/diffusers/models/attention_dispatch.py
+++ b/src/diffusers/models/attention_dispatch.py
@@ -676,7 +676,7 @@ def _wrapped_flash_attn_3(
 ) -> tuple[torch.Tensor, torch.Tensor]:
     # Hardcoded for now because pytorch does not support tuple/int type hints
     window_size = (-1, -1)
-    out, lse, *_ = flash_attn_3_func(
+    result = flash_attn_3_func(
         q=q,
         k=k,
         v=v,
@@ -693,7 +693,9 @@ def _wrapped_flash_attn_3(
         pack_gqa=pack_gqa,
         deterministic=deterministic,
         sm_margin=sm_margin,
+        return_attn_probs=True,
     )
+    out, lse, *_ = result
     lse = lse.permute(0, 2, 1)
     return out, lse
 
@@ -2623,7 +2625,7 @@ def _flash_varlen_attention_3(
     key_packed = torch.cat(key_valid, dim=0)
     value_packed = torch.cat(value_valid, dim=0)
 
-    out, lse, *_ = flash_attn_3_varlen_func(
+    result = flash_attn_3_varlen_func(
         q=query_packed,
         k=key_packed,
         v=value_packed,
@@ -2633,7 +2635,13 @@ def _flash_varlen_attention_3(
         max_seqlen_k=max_seqlen_k,
         softmax_scale=scale,
         causal=is_causal,
+        return_attn_probs=return_lse,
     )
+    if isinstance(result, tuple):
+        out, lse, *_ = result
+    else:
+        out = result
+        lse = None
 
     out = out.unflatten(0, (batch_size, -1))
     return (out, lse) if return_lse else out
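
The pattern both hunks introduce is the same: the flash-attn 3 call may hand back either a bare output tensor or a tuple whose leading elements are `(out, lse, ...)`, depending on whether attention probabilities/LSE were requested and on the installed build, so the result is unpacked defensively instead of being destructured unconditionally. A minimal sketch of that idea, extracted from the diff above — `_unpack_attn_result` is a hypothetical helper, not part of the patch:

```python
def _unpack_attn_result(result, want_lse):
    # Sketch of the defensive unpacking used in the diff above: `result`
    # is whatever flash_attn_3_func / flash_attn_3_varlen_func returned.
    if isinstance(result, tuple):
        # Tuple form: (out, lse, ...); trailing extras are ignored.
        out, lse, *_ = result
    else:
        # Bare-tensor form: only the attention output was returned.
        out, lse = result, None
    return (out, lse) if want_lse else out
```

Note the design choice in the varlen hunk: `return_attn_probs=return_lse` ties the extra return value to the caller's request, so the LSE is only produced (and the tuple form only taken) when `return_lse` is true, whereas the wrapped dense path always passes `return_attn_probs=True` because its signature promises `tuple[torch.Tensor, torch.Tensor]`.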