[JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims #2692

Open
KshitijLakhani wants to merge 3 commits into NVIDIA:main from KshitijLakhani:klakhani/fix/vmap-get-seg-ids-pos

KshitijLakhani (Collaborator) commented Feb 19, 2026

Description

What is the bug?

TE provides a convenience function, from_segment_ids_and_pos(), that lets users pass only segment ids; it returns a SequenceDescriptor holding the passed segment ids and internally generated segment pos.

As described in issue #2685, suppose a user vmaps a function forward() that (i) accepts q, k, v, and segment ids, (ii) calls from_segment_ids_and_pos(), and (iii) calls DPA(). JAX treats the segment ids as vmapped, so they gain an extra leading dimension (e.g. (1, 2, 128)), whereas the internally generated segment pos do not (e.g. (2, 128)). The resulting shape mismatch between seg ids and seg pos triggers the assert in the FusedAttn primitive's impl().

What is the root cause of the bug?

Debugging shows that the shapes start to differ when the batcher for the FusedAttn primitive is traced:
segment_ids in the primitive: treated as vmapped inputs, hence batched → (1, 2, 128).
segment_pos in the primitive: treated as derived within the function, hence not batched → (2, 128).
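
The batched/unbatched split described above can be reproduced in plain JAX without TE. The sketch below is only an illustration under assumptions: attn_like is a hypothetical stand-in for the FusedAttn primitive, jax.custom_batching.custom_vmap is used merely to observe which inputs JAX considers batched (it is not how TE registers its batcher), and the way forward() builds segment_pos from the static shape is an assumed simplification of what from_segment_ids_and_pos() does.

```python
import jax
import jax.numpy as jnp

# Records what the batching rule sees, for inspection after the call.
observed = {}


@jax.custom_batching.custom_vmap
def attn_like(segment_ids, segment_pos):
    # Hypothetical stand-in for the real primitive; the math is irrelevant here.
    return segment_ids + segment_pos


@attn_like.def_vmap
def attn_like_vmap(axis_size, in_batched, segment_ids, segment_pos):
    # in_batched holds one flag per argument: True if JAX batched it under vmap.
    observed["in_batched"] = tuple(in_batched)
    observed["shapes"] = (segment_ids.shape, segment_pos.shape)
    # Broadcasting hides the mismatch in this toy; a strict shape assert would not.
    return segment_ids + segment_pos, True


def forward(segment_ids):
    # Assumption: positions are built from the static shape only, so they never
    # depend on the vmapped values and JAX leaves them unbatched.
    seq_len = segment_ids.shape[-1]
    segment_pos = jnp.broadcast_to(
        jnp.arange(seq_len, dtype=jnp.int32), segment_ids.shape
    )
    return attn_like(segment_ids, segment_pos)


ids = jnp.ones((1, 2, 128), dtype=jnp.int32)
out = jax.vmap(forward)(ids)
print(observed["in_batched"], observed["shapes"])
```

Inside the rule, the segment ids arrive with the extra leading dimension while the positions keep their per-example shape, which mirrors the (1, 2, 128) versus (2, 128) mismatch reported in the issue.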

This PR resolves the mismatch by ensuring that segment_pos has the same leading batch dims as segment_ids, so end users can wrap TE API calls in vmap without worrying about batching inside TE.

Fixes #2685

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Ensure that the leading batch dims of segment pos match those of segment ids in the fused attn primitive's batcher.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

… the TE constructed segment pos are not thereby causing mismatches in impl()

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this Feb 19, 2026
KshitijLakhani (Collaborator, Author) commented:

/te-ci jax L0 L1 L2

@KshitijLakhani KshitijLakhani marked this pull request as ready for review February 20, 2026 06:54
greptile-apps bot (Contributor) commented Feb 20, 2026

Greptile Summary

This PR fixes a shape mismatch bug in the batchers of the JAX FusedAttn primitives that occurs when vmap is used with from_segment_ids_and_pos(). When segment_ids are passed to a vmapped function and segment_pos are generated internally, JAX traces segment_ids as batched (extra leading dimension) but segment_pos as non-batched, causing impl() to fail with a shape assertion error.

Key Changes:

  • Added batch dimension synchronization logic to both FusedAttnFwdPrimitive.batcher() and FusedAttnBwdPrimitive.batcher()
  • When segment_ids have batch dims but segment_pos don't, the code now expands segment_pos with matching leading dimensions via lax.expand_dims() and broadcasts to match segment_ids shape
  • Updates the batch_dims tuple to reflect the new batching for segment_pos tensors
  • Includes validation assertions to ensure shape compatibility before expansion
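
The synchronization steps listed above can be sketched as a small standalone helper. This is a minimal sketch, not the PR's actual code: the function name sync_segment_pos and its signature are hypothetical, and it assumes the batch dims of segment_ids are leading dims.

```python
import jax.numpy as jnp
from jax import lax


def sync_segment_pos(segment_ids, segment_pos, seg_id_bdim, seg_pos_bdim):
    """Hypothetical helper: give segment_pos the leading batch dims of segment_ids."""
    # Only act when segment_ids is batched and segment_pos is not.
    if seg_id_bdim is None or seg_pos_bdim is not None:
        return segment_pos, seg_pos_bdim

    # Number of leading dims that segment_ids has and segment_pos lacks.
    leading_bdim = segment_ids.ndim - segment_pos.ndim
    assert leading_bdim > 0, "segment_ids must have more dims than segment_pos"
    # The trailing dims must already agree before expanding.
    assert segment_ids.shape[leading_bdim:] == segment_pos.shape

    # Add size-1 leading dims, then broadcast to the full segment_ids shape.
    expanded = lax.expand_dims(segment_pos, tuple(range(leading_bdim)))
    synced = jnp.broadcast_to(expanded, segment_ids.shape)
    return synced, seg_id_bdim


segment_ids = jnp.ones((1, 2, 128), dtype=jnp.int32)
segment_pos = jnp.broadcast_to(jnp.arange(128, dtype=jnp.int32), (2, 128))
synced_pos, synced_bdim = sync_segment_pos(segment_ids, segment_pos, 0, None)
print(synced_pos.shape, synced_bdim)
```

Returning the updated batch dim alongside the array reflects the third bullet: the caller would write it back into its batch_dims tuple so the primitive sees consistent batching for both tensors.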

Potential Improvements:

  • The empty tensor check only validates the first segment_ids tensor (q) before processing both q and kv pairs in the loop
  • Consider adding per-tensor empty checks or documenting the assumption that q and kv segment tensors have consistent emptiness
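
One way the per-tensor check suggested above might look is sketched below. The helper name needs_pos_sync and the pair layout are hypothetical; the PR's actual empty-tensor check is not shown on this page, so this illustrates only the idea of deciding per pair rather than from the q tensor alone.

```python
import jax.numpy as jnp


def needs_pos_sync(segment_ids, segment_pos, seg_id_bdim, seg_pos_bdim):
    """Hypothetical per-pair guard: skip empty tensors, sync only on a bdim mismatch."""
    if segment_ids.size == 0 or segment_pos.size == 0:
        return False
    return seg_id_bdim is not None and seg_pos_bdim is None


# Each (ids, pos, ids_bdim, pos_bdim) tuple is checked on its own.
q_pair = (jnp.ones((1, 2, 128), jnp.int32), jnp.zeros((2, 128), jnp.int32), 0, None)
kv_pair = (jnp.zeros((0,), jnp.int32), jnp.zeros((0,), jnp.int32), 0, None)
decisions = [needs_pos_sync(*pair) for pair in (q_pair, kv_pair)]
print(decisions)
```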

Confidence Score: 4/5

  • This PR is safe to merge with low risk - it fixes a legitimate bug with a targeted solution
  • The fix appropriately addresses the root cause by synchronizing batch dimensions between segment_ids and segment_pos. The logic is well-commented and includes shape validation assertions. Minor style improvements suggested around empty tensor handling, but these don't affect correctness for the target use case. No tests included in the PR, but the fix aligns with the issue description and handles the vmap edge case properly.
  • No files require special attention - the changes are localized to two batcher methods with symmetric implementations

Important Files Changed

Filename: transformer_engine/jax/cpp_extensions/attention.py
Overview: Adds batching dimension synchronization logic to FusedAttn forward and backward primitives to handle vmap edge cases where segment_ids and segment_pos have mismatched batch dimensions

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[User calls vmap-wrapped function] --> B[Function generates segment_pos internally]
    B --> C[JAX traces FusedAttn batcher]
    C --> D{Check: seg_id_bdim != None AND seg_pos_bdim == None?}
    D -->|No| E[Pass args unchanged]
    D -->|Yes| F[Detect dimension mismatch]
    F --> G[For each segment_ids/pos pair]
    G --> H[Calculate leading_bdim = segment_ids.ndim - segment_pos.ndim]
    H --> I[Expand segment_pos with leading dimensions]
    I --> J[Broadcast segment_pos to match segment_ids shape]
    J --> K[Update batch_dims tuple]
    K --> L[Call primitive with synchronized shapes]
    E --> M[Primitive impl receives args]
    L --> M

Last reviewed commit: da19f26

@greptile-apps greptile-apps bot left a comment

1 file reviewed, 2 comments

Development

Successfully merging this pull request may close these issues.

JAX vmap issue with TE Attention