diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index eb2ebcff39..eec7cb40d0 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -104,7 +104,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   }
   const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
-  bool generate_stats = !return_max_logit;
+  bool generate_stats = true;  // Always return stats
   try {
     FADescriptor_v1 descriptor{
         b,
@@ -335,7 +335,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     sdpa_options.set_sink_token(softmax_offset);
   }

-  std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
+  std::shared_ptr<fe::graph::Tensor_attributes> Max;
   if (is_ragged_q && cudnn_runtime_version >= 90600) {
     offset_stats = mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -349,19 +349,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                                 .set_name("Max")
                                 .set_dim({b, h, s_q, 1})
                                 .set_data_type(fe::DataType_t::FLOAT));
-    Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                    .set_name("Sum_Exp")
-                                    .set_dim({b, h, s_q, 1})
-                                    .set_data_type(fe::DataType_t::FLOAT));
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
-      Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
     } else {
       Max->set_stride({h * s_q, s_q, 1, 1});
-      Sum_Exp->set_stride({h * s_q, s_q, 1, 1});
     }
     sdpa_options.set_logit_max(Max);
-    sdpa_options.set_score_sum_exp(Sum_Exp);
   }

   auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options));
@@ -379,13 +372,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     O->set_ragged_offset(offset_o);
   }

-  if (!return_max_logit) {
-    Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
-    if (is_ragged_q && cudnn_runtime_version >= 90600) {
-      Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
-    } else {
-      Stats->set_stride({h * s_q, s_q, 1, 1});
-    }
+  Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
+  if (is_ragged_q && cudnn_runtime_version >= 90600) {
+    Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
+  } else {
+    Stats->set_stride({h * s_q, s_q, 1, 1});
   }

   std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>,  // Q
@@ -395,7 +386,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
              std::shared_ptr<fe::graph::Tensor_attributes>>  // O
       key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
   auto Stats_tuple =
-      generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp);
+      return_max_logit ? std::make_tuple(Stats, Max) : std::make_tuple(Stats, nullptr);
   auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
   auto softmax_offset_tuple = is_softmax_offset ?
       std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
@@ -1125,6 +1116,16 @@ void fused_attn_arbitrary_seqlen_fwd(
   size_t i = 0;
   if (Aux_CTX_Tensors->size == 0) {
     const auto cudnn_runtime_version = cudnnGetVersion();
+
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    output_S->data.dptr = nullptr;
+    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+      output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
+    } else {
+      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
+    }
+    output_S->data.dtype = DType::kFloat32;
+
     if (return_max_logit) {
       Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_Max->data.dptr = nullptr;
@@ -1134,23 +1135,6 @@ void fused_attn_arbitrary_seqlen_fwd(
         output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_Max->data.dtype = DType::kFloat32;
-      Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      output_Sum_Exp->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-        output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
-      } else {
-        output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-      }
-      output_Sum_Exp->data.dtype = DType::kFloat32;
-    } else {
-      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      output_S->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-        output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
-      } else {
-        output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-      }
-      output_S->data.dtype = DType::kFloat32;
     }

     Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
@@ -1174,14 +1158,12 @@ void fused_attn_arbitrary_seqlen_fwd(
     Aux_CTX_Tensors->size = i;
   } else if (Aux_CTX_Tensors->size >= 2) {
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    devPtrS1 = output_S->data.dptr;
+
     if (return_max_logit) {
       Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      devPtrS1 = output_Max->data.dptr;
-      Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      devPtrS2 = output_Sum_Exp->data.dptr;
-    } else {
-      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      devPtrS1 = output_S->data.dptr;
+      devPtrS2 = output_Max->data.dptr;
     }
     Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_rng_state->data.dptr = rng_state->data.dptr;
diff --git a/transformer_engine/common/fused_attn/utils.h b/transformer_engine/common/fused_attn/utils.h
index 08a56cda6b..1ec1616c4a 100644
--- a/transformer_engine/common/fused_attn/utils.h
+++ b/transformer_engine/common/fused_attn/utils.h
@@ -118,7 +118,7 @@ struct FADescriptor_v1 {
   cudnn_frontend::DataType_t o_tensor_type;
   cudnn_frontend::DataType_t do_tensor_type;
   cudnn_frontend::DataType_t dqkv_tensor_type;
-  bool generate_max_sum_exp;
+  bool return_max_logit;

   bool operator<(const FADescriptor_v1 &rhs) const {
     return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
@@ -126,7 +126,7 @@ struct FADescriptor_v1 {
                     bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
                     softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
                     deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
-                    dqkv_tensor_type, generate_max_sum_exp) <
+                    dqkv_tensor_type, return_max_logit) <
            std::tie(rhs.b, rhs.h,
                     rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                     rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                     rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
@@ -134,7 +134,7 @@ struct FADescriptor_v1 {
                     rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
                     rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
                     rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
-                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+                    rhs.dqkv_tensor_type, rhs.return_max_logit);
   }
 };
diff --git a/transformer_engine/common/include/transformer_engine/fused_attn.h b/transformer_engine/common/include/transformer_engine/fused_attn.h
index cddd3d7506..9f62c2a089 100644
--- a/transformer_engine/common/include/transformer_engine/fused_attn.h
+++ b/transformer_engine/common/include/transformer_engine/fused_attn.h
@@ -206,7 +206,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
 *  \param[in]     head_dim_v          The head dimension of V.
 *  \param[in]     window_size_left    Sliding window size (the left half).
 *  \param[in]     window_size_right   Sliding window size (the right half).
- *  \param[in]     return_max_logit    Whether to produce Max and Sum_Exp, or Stats.
+ *  \param[in]     return_max_logit    Whether to produce Max along with Stats.
 *  \param[in]     cuda_graph          Whether cuda graph capture is enabled or not.
 *  \param[in]     deterministic       Whether determinism is required or not.
 */
@@ -260,7 +260,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
 *  \param[in]     max_seqlen          Max sequence length used for computing,
 *                                     it may be >= max(seqlen_i) for i=0,...batch_size-1.
 *  \param[in]     is_training         Whether this is in training mode or inference.
- *  \param[in]     return_max_logit    Whether to produce Max and Sum_Exp, or Stats.
+ *  \param[in]     return_max_logit    Whether to produce Max along with Stats.
 *  \param[in]     cuda_graph          Whether cuda graph capture is enabled or not.
 *  \param[in]     attn_scale          Scaling factor for Q * K.T.
 *  \param[in]     dropout             Dropout probability.
@@ -400,7 +400,7 @@ void nvte_fused_attn_bwd_qkvpacked(
 *  \param[in]     max_seqlen_kv       Max sequence length used for computing for KV.
 *                                     it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
 *  \param[in]     is_training         Whether this is in training mode or inference.
- *  \param[in]     return_max_logit    Whether to produce Max and Sum_Exp, or Stats.
+ *  \param[in]     return_max_logit    Whether to produce Max along with Stats.
 *  \param[in]     cuda_graph          Whether cuda graph capture is enabled or not.
 *  \param[in]     attn_scale          Scaling factor for Q * K.T.
 *  \param[in]     dropout             Dropout probability.
@@ -553,7 +553,7 @@ void nvte_fused_attn_bwd_kvpacked(
 *  \param[in]     max_seqlen_kv       Max sequence length used for computing for K and V.
 *                                     it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
 *  \param[in]     is_training         Whether this is in training mode or inference.
- *  \param[in]     return_max_logit    Whether to produce Max and Sum_Exp, or Stats.
+ *  \param[in]     return_max_logit    Whether to produce Max along with Stats.
 *  \param[in]     cuda_graph          Whether cuda graph capture is enabled or not.
 *  \param[in]     attn_scale          Scaling factor for Q * K.T.
 *  \param[in]     dropout             Dropout probability.
diff --git a/transformer_engine/pytorch/cpp_extensions/fused_attn.py b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
index 101e5b2525..fffc9faf28 100644
--- a/transformer_engine/pytorch/cpp_extensions/fused_attn.py
+++ b/transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -339,14 +339,14 @@ def fused_attn_fwd(
     if return_max_logit:
         qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
-        # thd:  output_tensors: out [tq, h, d],    Max [tq, h, 1],    Sum_Exp [tq, h, 1]
-        # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
-        # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
-        stats = output_tensors[1] + torch.log(output_tensors[2])
+        # thd:  output_tensors: out [tq, h, d],    Stats [tq, h, 1],    Max [tq, h, 1]
+        # bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
+        # sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]  (not a typo: sbhd also stores Stats/Max as [b, h, sq, 1])
+
+        aux_ctx_tensors = [output_tensors[1]]  # "Stats"
         amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
         # Max -> max_logit [h]
-        max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
-        aux_ctx_tensors = [stats]
+        max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
         aux_ctx_tensors.extend(output_tensors[3:])
         return output_tensors[0], aux_ctx_tensors, max_logit
diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp
index bf62db8c33..ff60bb87bb 100644
--- a/transformer_engine/pytorch/csrc/extensions/attention.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -259,16 +259,16 @@ std::vector<py::object> fused_attn_fwd(
   // f16_max512   : S [b, h, sq, skv]
   // f16_arbitrary:
   //   return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
-  //   return_max_logit=true:  Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+  //   return_max_logit=true:  S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
   // fp8          : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
   size_t i = 0;
   at::Tensor output_tensor;
-  // intermediate softmax tensor, S or M
+  // intermediate softmax tensor, S or M (for fp8)
   output_tensor =
       allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
                     static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
   set_tensor_param(i++, output_tensor);
-  // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
+  // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor
   if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
     output_tensor =
         allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
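For reference, a minimal Python sketch of how the reordered forward outputs are consumed under the new contract: cuDNN now emits Stats directly (previously the Python side reconstructed it as `Max + log(Sum_Exp)`), so the only remaining post-processing is the per-head `max_logit` reduction over the Max tensor. The helper name `split_aux_outputs` and the toy shapes below are invented for illustration; this is not Transformer Engine API.

```python
import torch


def split_aux_outputs(output_tensors, qkv_format: str):
    """Sketch of the new return_max_logit=True post-processing.

    Assumes output_tensors = [out, Stats, Max, rng_state, ...], mirroring the
    fused_attn.py hunk above: Stats is kept as the softmax stats for backward,
    and Max is reduced to a per-head max_logit.
    """
    out, stats, max_tensor = output_tensors[0], output_tensors[1], output_tensors[2]
    # thd stores Stats/Max as [tq, h, 1]; bshd/sbhd store them as [b, h, sq, 1].
    amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
    max_logit = torch.amax(max_tensor, dim=amax_dims).to(dtype=out.dtype)  # -> [h]
    aux_ctx_tensors = [stats] + list(output_tensors[3:])
    return out, aux_ctx_tensors, max_logit


# Toy shapes only (b=2, h=4, sq=8, d=16), purely illustrative.
b, h, sq, d = 2, 4, 8, 16
dummy_outputs = [
    torch.randn(b, sq, h, d),            # out (bshd)
    torch.randn(b, h, sq, 1),            # Stats
    torch.randn(b, h, sq, 1),            # Max
    torch.zeros(2, dtype=torch.int64),   # rng_state
]
out, aux, max_logit = split_aux_outputs(dummy_outputs, qkv_format="bshd")
assert max_logit.shape == (h,)
```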