Fix Flash Attention 3 interface for new FA3 return format#13173
Open
veeceey wants to merge 2 commits intohuggingface:mainfrom
Open
Fix Flash Attention 3 interface for new FA3 return format#13173veeceey wants to merge 2 commits intohuggingface:mainfrom
veeceey wants to merge 2 commits intohuggingface:mainfrom
Conversation
Newer versions of flash-attn (after Dao-AILab/flash-attention@ed20940) no longer return lse by default from flash_attn_3_func. The function now returns just the output tensor unless return_attn_probs=True is passed. Updated _wrapped_flash_attn_3 and _flash_varlen_attention_3 to pass return_attn_probs and handle both old (always tuple) and new (tensor or tuple) return formats gracefully. Fixes huggingface#12022
Author
Test ResultsCan't run FA3 tests locally (no CUDA GPU), but verified the logic:
The fix is backwards-compatible: old FA3 versions that always return tuples will still work since we check |
DN6
reviewed
Feb 23, 2026
Collaborator
DN6
left a comment
There was a problem hiding this comment.
Minor comment. Looks good otherwise 👍🏽
Comment on lines
699
to
704
| if isinstance(result, tuple): | ||
| out, lse, *_ = result | ||
| lse = lse.permute(0, 2, 1) | ||
| else: | ||
| out = result | ||
| lse = torch.empty(q.shape[0], q.shape[2], q.shape[1], device=q.device, dtype=torch.float32) |
Collaborator
There was a problem hiding this comment.
Don't think we need this guard. In both cases (old vs new FA3) we're always returning a tuple since return_attn_probs=True? Why not just leave as out, lse, *_ = flash_attn_3_func
Since return_attn_probs=True is always passed, the result is guaranteed to be a tuple. Remove the unnecessary isinstance guard.
DN6
approved these changes
Feb 24, 2026
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
After Dao-AILab/flash-attention@ed20940,
flash_attn_3_funcno longer returns(out, lse, ...)by default -- it just returnsout. This breaks_wrapped_flash_attn_3which unconditionally unpacksout, lse, *_:This PR:
return_attn_probs=Truetoflash_attn_3_func(consistent with how_flash_attention_3_hub_forward_opalready handles it)_flash_varlen_attention_3which had the same issueFixes #12022