add SP support for flash_varlen_hub backend #13479

zhtmike wants to merge 18 commits into huggingface:main

Conversation
code snippet to show it works
Hi @sayakpaul, the PR is ready for review. Please take a look once you have time.
sayakpaul left a comment
Thanks a lot for the PR! I left some comments, LMK what you think.
Should it be propagated to FA3, too, perhaps in a different PR?
```diff
  try:
      from flash_attn import flash_attn_func, flash_attn_varlen_func
-     from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
+     from flash_attn.flash_attn_interface import (
```
WDYT of constraining the changes only to FLASH_HUB?
This way, people won't have to build the flash attention wheel locally.
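For context, switching a loaded pipeline to the Hub variant is a one-liner (a sketch, assuming the `kernels` package is installed and `pipe` is an already-loaded diffusers pipeline):

```python
# "flash_hub" fetches a prebuilt flash-attn kernel from the Hub via `kernels`,
# so users don't need to compile the flash-attn wheel locally.
pipe.transformer.set_attention_backend("flash_hub")
```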
We will be deprecating the non-Hub variants for `FLASH` and `FLASH_3` soonish anyway.
OK let me move to the hub version of flash attention 2 then
Done. Moved to FLASH_HUB.
```python
@dataclass
class _VarlenPackedInputs:
```
Does it apply to all varlen attention kernels, though? Or does it come to fruition only during CP?
We do have VARLEN implementations of a few backends already:
Any reason to use dataclasses for this? Won't it be better to apply the transformations inline for readability?
> We do have VARLEN implementations of a few backends already:

Yes, but it seems all the VARLEN implementations only work in the non-CP case. So it is fine to apply the transformations inline for non-CP.
> Any reason to use dataclasses for this? Won't it be better to apply the transformations inline for readability?

Let me drop the dataclasses, it is just my personal preference.
```python
return packed_out.view(self.batch_size, self.seq_len_q, *packed_out.shape[1:])
```

```python
def _pack_qkv(
```
Why do we need this function if we decided to do the transformations in-line?
```python
if attn_mask is not None:
    if return_lse:
        raise NotImplementedError("`return_lse=True` with `attn_mask` is not yet supported for flash-attn 2.")
```
That means Ring isn't supported yet, right?
Yes. Not supported yet.
```python
import torch
import torch.nn.functional as F

from diffusers.models.attention_dispatch import (
```
Added test coverage. The xfail for `_flash_3_hub` will be dropped once flash_3 with attention_mask is added.
Result:

```text
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[True-native-ulysses_degree] PASSED
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[True-native-ring_degree] SKIPPED
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[True-flash_hub-ulysses_degree] PASSED
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[True-flash_hub-ring_degree] XFAIL
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[True-_flash_3_hub-ulysses_degree] XFAIL
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[True-_flash_3_hub-ring_degree] XFAIL
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[False-native-ulysses_degree] PASSED
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[False-native-ring_degree] SKIPPED
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[False-flash_hub-ulysses_degree] PASSED
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[False-flash_hub-ring_degree] XFAIL
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[False-_flash_3_hub-ulysses_degree] XFAIL
tests/models/transformers/test_models_transformer_qwenimage.py::TestQwenImageTransformerContextParallelAttnBackends::test_context_parallel_attn_backend_inference[False-_flash_3_hub-ring_degree] XFAIL
```
Agree. I will implement FA3 as well, in a different PR.

@claude could you do a review?
Claude finished @sayakpaul's task in 3m 25s. View job

Review of PR #13479: Add attention mask input support for flash backend

Overall: The approach is correct (using …). I have a few comments:
1. Missing None-check on …
sayakpaul left a comment
IIUC we have to rely on varlen if an attention mask is specified. If that's the case, we should rather specify …
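Something like the following minimal sketch of that dispatch idea, with `_flash_attention` / `_flash_varlen_attention` as hypothetical stand-ins for the real backend functions:

```python
from typing import Callable, Optional
import torch

# Hypothetical stand-ins for the actual dense and varlen flash kernels.
_flash_attention: Callable = lambda q, k, v, **kw: q
_flash_varlen_attention: Callable = lambda q, k, v, **kw: q

def _dispatch_flash(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                    attn_mask: Optional[torch.Tensor] = None, **kwargs):
    # The dense flash kernel cannot consume an attention mask, so any masked
    # call must be routed to the variable-length kernel.
    if attn_mask is not None:
        return _flash_varlen_attention(query, key, value, attn_mask=attn_mask, **kwargs)
    return _flash_attention(query, key, value, **kwargs)
```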
```python
wrapped_forward_fn: Callable | None = None
wrapped_backward_fn: Callable | None = None
# Some backends (e.g. flash attention) have separate kernels for variable-length inputs
varlen_function_attr: str | None = None
```
Instead of introducing a new attribute for varlen, I think we should do something similar to:
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
Agree. Let me implement CP with the varlen kernel instead. It will look cleaner.
Hi @sayakpaul, I have reworked the CP with the varlen kernel. Now `QwenImagePipeline` supports CP with … I have tested with …, and the previous … Can you take a look? Thanks!
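For reference, a rough sketch of what running CP with this backend could look like (API names such as `ContextParallelConfig` and `enable_parallelism` follow diffusers' context-parallel docs and are assumptions here, since the exact snippet was lost above):

```python
# Assumed usage; launch with e.g. `torchrun --nproc_per_node=2 demo.py`.
# Distributed init and per-rank device handling are omitted for brevity.
import torch
from diffusers import ContextParallelConfig, QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.set_attention_backend("flash_varlen_hub")
pipe.transformer.enable_parallelism(config=ContextParallelConfig(ulysses_degree=2))
image = pipe("a cat holding a sign", num_inference_steps=20).images[0]
```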
sayakpaul left a comment
Thanks for the refactor. I don't really understand some of the big changes to the existing codebase. So, please provide reasoning behind them.
```python
def _padded_to_unpad(tensor: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
    """gather valid tokens from a padded `(batch, seq, ...)` tensor into a packed `(nnz, ...)` tensor."""
    return tensor.reshape(-1, *tensor.shape[2:])[indices]
```
This is just a one-liner utility. Let's use it directly at the call sites.
```diff
- (_, seqlens_k), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
-     _prepare_for_flash_attn_or_sage_varlen(
-         batch_size, seq_len_q, seq_len_kv, attn_mask=attn_mask, device=query.device
+ if _parallel_config is not None:
```
Let's follow this pattern:
```python
if attn_mask is not None:
    attn_mask_2d = _normalize_attn_mask(attn_mask, batch_size, seq_len_kv)
    (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen_with_mask(batch_size, seq_len_q, attn_mask_2d, query.device)
    )
    indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten()
    key_packed = _padded_to_unpad(key, indices_k)
    value_packed = _padded_to_unpad(value, indices_k)
else:
    (_, _), (cu_seqlens_q, cu_seqlens_k), (max_seqlen_q, max_seqlen_k) = (
        _prepare_for_flash_attn_or_sage_varlen_without_mask(batch_size, seq_len_q, seq_len_kv, query.device)
    )
    key_packed = key.flatten(0, 1)
    value_packed = value.flatten(0, 1)
```
For example, assume batch_size=2, seq_len_kv=4. There are two branches.

If there is a mask like

```text
batch 0: [T, T, T, F]  ← 3 real tokens
batch 1: [T, T, F, F]  ← 2 real tokens
```

the masked branch:

- normalizes the 4D mask to a 2D mask `[batch, seq_kv]` if necessary
- computes `cu_seqlens_k = [0, 3, 5]` (cumulative token counts: 0 → 3 → 5)
- finds `indices_k = [0, 1, 2, 4, 5]` (the flat indices of the True positions)
- gathers only those rows → `key_packed` with shape (5, heads, dim)

If there is no mask, the other branch:

- computes `cu_seqlens_k = [0, 4, 8]`
- finds `indices_k = [0, 1, 2, 3, 4, 5, 6, 7]`
- computes `key_packed` with shape (8, heads, dim)

Then both feed into the varlen kernel.

It is a vectorized way of handling key/value packing compared with the old for-loop method. And it can also handle a QwenImage-like mask of [T, T, T, F, F, T], which the old way cannot. (A runnable sketch of the masked branch follows below.)
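A minimal runnable sketch of the masked branch with the numbers above (shapes and the inlined gather are illustrative, not the PR's exact code):

```python
import torch
import torch.nn.functional as F

batch_size, seq_len_kv, heads, dim = 2, 4, 8, 64
key = torch.randn(batch_size, seq_len_kv, heads, dim)
attn_mask_2d = torch.tensor([[True, True, True, False],   # batch 0: 3 real tokens
                             [True, True, False, False]])  # batch 1: 2 real tokens

# cumulative token counts per batch entry -> cu_seqlens_k = [0, 3, 5]
seqlens_k = attn_mask_2d.sum(dim=1, dtype=torch.int32)
cu_seqlens_k = F.pad(torch.cumsum(seqlens_k, dim=0, dtype=torch.int32), (1, 0))

# flat indices of the True positions -> indices_k = [0, 1, 2, 4, 5]
indices_k = attn_mask_2d.flatten().nonzero(as_tuple=False).flatten()

# gather only the valid rows -> key_packed has shape (nnz, heads, dim) == (5, 8, 64)
key_packed = key.reshape(-1, heads, dim)[indices_k]

assert cu_seqlens_k.tolist() == [0, 3, 5]
assert indices_k.tolist() == [0, 1, 2, 4, 5]
assert key_packed.shape == (5, heads, dim)
```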
```python
marks=pytest.mark.skipif(not is_kernels_available(), reason="`kernels` is not available."),
),
pytest.param(
    "flash_varlen_hub",
```
Should varlen tests get their own testing mixin class?
I think the varlen kernel can handle all the cases supported by the non-varlen kernel. Personally, I prefer to put them together.
| """Context Parallel inference x attention backends tests for QwenImage Transformer""" | ||
|
|
||
| # flash_hub and _flash_3_hub do not support attn_mask | ||
| unsupported_attn_backends = ["flash_hub", "_flash_3_hub"] |
Any non-varlen attention backend would fail, no? If so, I would rather do something like:

```python
if "varlen" not in attention_backend:
    pytest.skip(...)
```
Like FluxPipeline, it can also support varlen kernels after this change.

I'm not sure what the most suitable place is to put this.
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>


What does this PR do?
This PR adds support for attention mask input when using the attention backend with `set_attention_backend("flash")`. With this change, `QwenImagePipeline` can run with the flash backend w/ or w/o Ulysses SP.

For FlashAttention 2, it is not feasible to use `_wrapped_flash_attn_forward` directly when a mask is applied. To maintain compatibility with the current interface, we introduce an additional branch for FlashAttention to handle attention masks.

I haven't tested with ring attention, so it is left as unimplemented.
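A rough usage sketch matching the description above (the checkpoint and prompt are placeholders, not taken from this PR):

```python
import torch
from diffusers import QwenImagePipeline

pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
# After this PR, the flash backend also accepts inputs with an attention mask.
pipe.transformer.set_attention_backend("flash")
image = pipe("a photo of a cat", num_inference_steps=20).images[0]
```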
Fixes # (issue)
Before submitting

- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
@sayakpaul
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.