
Fix Flash Attention 3 interface for new FA3 return format#13173

Open
veeceey wants to merge 2 commits into huggingface:main from veeceey:fix/issue-12022-flash-attention-3-interface

Conversation


@veeceey veeceey commented Feb 23, 2026

After Dao-AILab/flash-attention@ed20940, flash_attn_3_func no longer returns (out, lse, ...) by default; it returns just out. This breaks _wrapped_flash_attn_3, which unconditionally unpacks out, lse, *_:

ValueError: not enough values to unpack (expected at least 2, got 1)

This PR:

  • Passes return_attn_probs=True to flash_attn_3_func (consistent with how _flash_attention_3_hub_forward_op already handles it)
  • Adds a fallback for robustness in case the return format still varies
  • Applies the same fix to _flash_varlen_attention_3 which had the same issue

Fixes #12022
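A minimal sketch of the unpacking pattern this PR introduces (hypothetical names; simple stubs stand in for flash_attn_3_func, since the unpack logic is what matters here, not the CUDA kernel):

```python
# Sketch of the fix: always pass return_attn_probs=True, then unpack
# defensively. fake_fa3_* are stand-ins for flash_attn_3_func; the real
# change lives in diffusers' _wrapped_flash_attn_3.

def fake_fa3_old(q, return_attn_probs=False):
    # Old FA3: always returns (out, lse, ...) regardless of the flag.
    return (q * 2, "lse", None)

def fake_fa3_new(q, return_attn_probs=False):
    # New FA3 (after ed20940): returns just `out` unless the flag is set.
    return (q * 2, "lse") if return_attn_probs else q * 2

def wrapped_flash_attn_3(fa3_func, q):
    result = fa3_func(q, return_attn_probs=True)
    if isinstance(result, tuple):
        out, lse, *_ = result      # old FA3, or new FA3 with the flag set
    else:
        out, lse = result, None    # fallback in case the format still varies
    return out, lse
```

Both stub versions flow through the same wrapper without the unpack error.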

Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940)
no longer return lse by default from flash_attn_3_func. The function
now returns just the output tensor unless return_attn_probs=True is
passed.

Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass
return_attn_probs and handle both old (always tuple) and new (tensor
or tuple) return formats gracefully.

Fixes huggingface#12022
Author

veeceey commented Feb 23, 2026

Test Results

Can't run FA3 tests locally (no CUDA GPU), but verified the logic:

  1. _wrapped_flash_attn_3: Now passes return_attn_probs=True and handles both tuple (old FA3) and non-tuple (new FA3 fallback) returns
  2. _flash_varlen_attention_3: Same pattern, only requests return_attn_probs when return_lse=True is passed by the caller
  3. Consistent with how _flash_attention_3_hub_forward_op already handles this at line ~1258 (passes return_attn_probs=return_lse and conditionally unpacks)

The fix is backwards-compatible: old FA3 versions that always return tuples will still work since we check isinstance(result, tuple).
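The varlen pattern described in point 2 can be sketched similarly (illustrative names; a stub stands in for the FA3 varlen kernel, and the real function is diffusers' _flash_varlen_attention_3):

```python
# Illustrative sketch: only request the LSE when the caller asked for it.
# fake_varlen_fa3 stands in for the FA3 varlen attention function.

def fake_varlen_fa3(q, return_attn_probs=False):
    # New-style FA3: a tuple only when probs/LSE are requested.
    return (q * 2, "lse") if return_attn_probs else q * 2

def flash_varlen_attention_3(q, return_lse=False):
    result = fake_varlen_fa3(q, return_attn_probs=return_lse)
    if return_lse:
        # A tuple is expected here; fall back defensively if the format varies.
        out, lse = result if isinstance(result, tuple) else (result, None)
        return out, lse
    return result
```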

Collaborator

@DN6 DN6 left a comment


Minor comment. Looks good otherwise 👍🏽

Comment on lines 699 to 704

    if isinstance(result, tuple):
        out, lse, *_ = result
        lse = lse.permute(0, 2, 1)
    else:
        out = result
        lse = torch.empty(q.shape[0], q.shape[2], q.shape[1], device=q.device, dtype=torch.float32)
Collaborator

Don't think we need this guard. In both cases (old vs new FA3) we're always getting a tuple back since return_attn_probs=True. Why not just leave it as out, lse, *_ = flash_attn_3_func(...)?

Since return_attn_probs=True is always passed, the result is
guaranteed to be a tuple. Remove the unnecessary isinstance guard.
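The simplified form after this commit, sketched with a stub in place of flash_attn_3_func (with return_attn_probs=True the result is always a tuple, so the direct unpack is safe):

```python
def fa3_stub(q, return_attn_probs=False):
    # Stand-in for flash_attn_3_func: with return_attn_probs=True the
    # result is always a tuple, so direct unpacking needs no guard.
    return (q * 2, "lse") if return_attn_probs else q * 2

out, lse, *_ = fa3_stub(3, return_attn_probs=True)
```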
Collaborator

@DN6 DN6 left a comment


Thanks @veeceey

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.



Development

Successfully merging this pull request may close these issues.

_flash_attention_3 in dispatch_attention_fn is not compatible with the latest flash-atten interface.

3 participants