From 998b3b8667f0a73bb598296bd52785c9ab8cf2ad Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani Date: Fri, 20 Feb 2026 18:29:43 +0000 Subject: [PATCH 1/2] Enable sm120 support for fused attn if cuDNN is 9.18.1+ Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 0c5a519813..2932e8cdf5 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,6 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False + #TODO: KL check if this condition is now supported or not ? if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0) @@ -690,11 +691,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]" ) use_flash_attention = False - if device_compute_capability == (12, 0): + if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1): if use_fused_attention: logger.debug( "Disabling FusedAttention as qkv_format = thd is" - " not supported for compute capability = sm120" + " not supported for compute capability = sm120 and cuDNN version < 9.18.1" ) use_fused_attention = False From dc282ea5631ce1e573e7423985f0cf16bbe9fbc0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 20 Feb 2026 18:42:10 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 2932e8cdf5..c86419f189 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -610,7 +610,7 @@ def get_attention_backend( qkv_layout, ) use_fused_attention = False - #TODO: KL check if this condition is now supported or not ? + # TODO: KL check if this condition is now supported or not ? if ( device_compute_capability == (12, 0) and (head_dim_qk > 128 or head_dim_qk % 8 != 0)