[JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims #2692
Open
KshitijLakhani wants to merge 3 commits into NVIDIA:main
… the TE constructed segment pos are not thereby causing mismatches in impl() Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
/te-ci jax L0 L1 L2
Greptile Summary
This PR fixes a shape mismatch bug in JAX's FusedAttn primitive batchers that occurs when using
Key Changes:
Potential Improvements:
Confidence Score: 4/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[User calls vmap-wrapped function] --> B[Function generates segment_pos internally]
B --> C[JAX traces FusedAttn batcher]
C --> D{Check: seg_id_bdim != None AND seg_pos_bdim == None?}
D -->|No| E[Pass args unchanged]
D -->|Yes| F[Detect dimension mismatch]
F --> G[For each segment_ids/pos pair]
G --> H[Calculate leading_bdim = segment_ids.ndim - segment_pos.ndim]
H --> I[Expand segment_pos with leading dimensions]
I --> J[Broadcast segment_pos to match segment_ids shape]
J --> K[Update batch_dims tuple]
K --> L[Call primitive with synchronized shapes]
E --> M[Primitive impl receives args]
L --> M
Last reviewed commit: da19f26 |
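The expand-and-broadcast steps in the flowchart above can be sketched roughly as follows. This is a minimal illustration in plain JAX, not the code from this PR: the helper name `match_leading_batch_dims` is hypothetical, and the sketch does not cover the "Update batch_dims tuple" step, which depends on the batcher's own bookkeeping.

```python
import jax.numpy as jnp


def match_leading_batch_dims(segment_ids, segment_pos):
    """Hypothetical helper: give segment_pos the same shape as segment_ids."""
    # Number of leading (batch) dims that segment_ids has and segment_pos lacks.
    leading_bdim = segment_ids.ndim - segment_pos.ndim
    if leading_bdim <= 0:
        # Nothing to add; shapes already have the same rank.
        return segment_pos
    # Insert size-1 leading axes, then broadcast them to the batch sizes.
    expanded = jnp.expand_dims(segment_pos, axis=tuple(range(leading_bdim)))
    return jnp.broadcast_to(expanded, segment_ids.shape)


# Shapes taken from the PR description: batched ids, unbatched positions.
segment_ids = jnp.ones((1, 2, 128), dtype=jnp.int32)
segment_pos = jnp.broadcast_to(jnp.arange(128, dtype=jnp.int32), (2, 128))

synced = match_leading_batch_dims(segment_ids, segment_pos)
print(synced.shape)  # (1, 2, 128)
```

Broadcasting (rather than recomputing positions per batch element) is enough here because the positions do not depend on the batched values, only on the per-example shape.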
Description
What is the bug?
TE provides a convenience function from_segment_ids_and_pos() which allows users to pass only segment ids; the function returns a SequenceDescriptor containing the passed segment ids and internally generated segment pos.
As mentioned in Issue #2685, suppose a user vmaps a function forward() which i) accepts q, k, v and segment ids, ii) calls from_segment_ids_and_pos(), and iii) then calls DPA(). JAX sees the segment ids as vmapped, so an extra leading dimension is added (e.g. (1, 2, 128)), whereas the segment pos are not given a leading dimension (e.g. (2, 128)). This triggers the assert in the FusedAttn primitive's impl() because of the shape mismatch between seg ids and seg pos.
What is the root cause for the bug?
On debugging, it can be seen that the shapes start to differ when the batcher is being traced for the FusedAttn primitive.
segment_ids in the primitive: treated as vmapped inputs, hence batched → (1, 2, 128).
segment_pos in the primitive: treated as derived within the function, hence not batched → (2, 128).
This PR aims to resolve this by ensuring that segment_pos has the same leading batching dims as segment_ids, so the end user can vmap wrap the TE API calls without worrying about the batching in TE.
Fixes #2685
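The root cause can be reproduced outside TE with a small toy. The sketch below is an assumption-laden illustration, not TE code: `toy_attn` and `forward` are hypothetical stand-ins, and `jax.custom_batching.custom_vmap` is used only because its rule exposes which arguments JAX considers batched. Positions built from the static per-example shape inside the vmapped function reach the batching rule without a batch dimension.

```python
import jax
import jax.numpy as jnp
from jax.custom_batching import custom_vmap

seen = {}  # filled in by the batching rule below


@custom_vmap
def toy_attn(segment_ids, segment_pos):
    # Hypothetical stand-in for the fused attention primitive.
    return segment_ids + segment_pos


@toy_attn.def_vmap
def toy_attn_vmap_rule(axis_size, in_batched, segment_ids, segment_pos):
    # Record what the batching rule actually receives.
    seen["in_batched"] = tuple(bool(b) for b in in_batched)
    seen["ids_shape"] = segment_ids.shape
    seen["pos_shape"] = segment_pos.shape
    # Align the shapes by hand so the toy op can still run.
    out = segment_ids + jnp.broadcast_to(segment_pos, segment_ids.shape)
    return out, True  # the output carries the batch axis


def forward(segment_ids):
    # Built from the static per-example shape only, so vmap does not batch it.
    pos = jnp.broadcast_to(
        jnp.arange(segment_ids.shape[-1], dtype=segment_ids.dtype),
        segment_ids.shape,
    )
    return toy_attn(segment_ids, pos)


batched_ids = jnp.ones((1, 2, 128), dtype=jnp.int32)
out = jax.vmap(forward)(batched_ids)

print(seen["in_batched"])  # (True, False)
print(seen["ids_shape"])   # (1, 2, 128)
print(seen["pos_shape"])   # (2, 128)
```

An implementation that asserts equal shapes for the two arguments would fail at this point, which matches the behaviour reported in the issue.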
Type of change
Changes
Ensure that the leading batch dims of segment pos match those of segment ids in the fused attn primitive's batcher
Checklist: