diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
index 0c5a519813..c86419f189 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py
@@ -610,6 +610,7 @@ def get_attention_backend(
             qkv_layout,
         )
         use_fused_attention = False
+    # TODO(KL): check whether this condition is now supported
     if (
         device_compute_capability == (12, 0)
         and (head_dim_qk > 128 or head_dim_qk % 8 != 0)
@@ -690,11 +691,11 @@ def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dt
                 "padding between sequences, i.e. [a, a, PAD, b, b, b, PAD, c, PAD]"
             )
             use_flash_attention = False
-        if device_compute_capability == (12, 0):
+        if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1):
             if use_fused_attention:
                 logger.debug(
                     "Disabling FusedAttention as qkv_format = thd is"
                     " not supported for compute capability = sm120 and cuDNN version < 9.18.1"
-                    " not supported for compute capability = sm120"
+                    " not supported for compute capability = sm120 and cuDNN version < 9.18.1"
                 )
                 use_fused_attention = False
 
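
For context, here is a minimal, self-contained sketch of how the new gate behaves. The helper name `fused_attention_allowed_for_thd` is hypothetical and not part of TransformerEngine; it only assumes, as in `get_attention_backend`, that `device_compute_capability` is a `(major, minor)` tuple and `cudnn_version` is a `(major, minor, patch)` tuple, so `<` is Python's lexicographic tuple comparison.

```python
# Hypothetical helper, not part of TransformerEngine -- illustrates the patched gate only.
def fused_attention_allowed_for_thd(device_compute_capability, cudnn_version):
    """Mirror the new condition: disable thd FusedAttention on sm120 only when cuDNN < 9.18.1."""
    if device_compute_capability == (12, 0) and cudnn_version < (9, 18, 1):
        return False
    return True

# Tuples compare lexicographically: (9, 17, 0) < (9, 18, 1) is True, (9, 18, 1) < (9, 18, 1) is False.
assert not fused_attention_allowed_for_thd((12, 0), (9, 17, 0))  # older cuDNN: still disabled
assert fused_attention_allowed_for_thd((12, 0), (9, 18, 1))      # cuDNN 9.18.1+: now allowed
assert fused_attention_allowed_for_thd((9, 0), (9, 17, 0))       # non-sm120: unaffected
```

With this change, FusedAttention with `qkv_format = thd` on sm120 is only disabled for cuDNN builds older than 9.18.1, instead of unconditionally.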