[JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims#2692
Open
KshitijLakhani wants to merge 5 commits intoNVIDIA:mainfrom
Open
[JAX] Fix batcher in FusedAttn primitive for when seg ids bdims != seg pos bdims#2692KshitijLakhani wants to merge 5 commits intoNVIDIA:mainfrom
KshitijLakhani wants to merge 5 commits intoNVIDIA:mainfrom