[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True
#2677
base: main
Changes from all commits
Commits: 3fb19fc, 5b40701, 5d479ad, 296fb9f, 24bfd45, fd42feb, 2d7b51b, 260380b, 7a5ab35, 9710810, 07db752, f8b1a68, b5b2b9d, 8f40cab, 56e46fd
@@ -104,7 +104,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   }
 
   const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
-  bool generate_stats = !return_max_logit;
+  bool generate_stats = true;  // Always return stats
   try {
     FADescriptor_v1 descriptor{
         b,
@@ -335,7 +335,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     sdpa_options.set_sink_token(softmax_offset);
   }
 
-  std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
+  std::shared_ptr<fe::graph::Tensor_attributes> Max;
   if (is_ragged_q && cudnn_runtime_version >= 90600) {
     offset_stats =
         mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -349,19 +349,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
                             .set_name("Max")
                             .set_dim({b, h, s_q, 1})
                             .set_data_type(fe::DataType_t::FLOAT));
-    Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes()
-                                    .set_name("Sum_Exp")
-                                    .set_dim({b, h, s_q, 1})
-                                    .set_data_type(fe::DataType_t::FLOAT));
     if (is_ragged_q && cudnn_runtime_version >= 90600) {
       Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
-      Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
     } else {
       Max->set_stride({h * s_q, s_q, 1, 1});
-      Sum_Exp->set_stride({h * s_q, s_q, 1, 1});
     }
     sdpa_options.set_logit_max(Max);
-    sdpa_options.set_score_sum_exp(Sum_Exp);
   }
 
   auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options));
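With Sum_Exp gone, the graph keeps only sdpa_options.set_logit_max(Max) and relies on the Stats output cuDNN already produces. Stats is the row-wise logsumexp of the scaled attention logits, so the old Max/Sum_Exp pair carried no information beyond Stats and Max (Stats = Max + log(Sum_Exp)). A minimal PyTorch sketch of that identity on toy logits (not the fused kernel; shapes chosen arbitrarily):

```python
import torch

# Toy scaled attention logits for one (batch, head): [s_q, s_kv]
logits = torch.randn(8, 8)

row_max = logits.amax(dim=-1, keepdim=True)                      # "Max"      [s_q, 1]
sum_exp = torch.exp(logits - row_max).sum(dim=-1, keepdim=True)  # "Sum_Exp"  [s_q, 1]
stats = row_max + torch.log(sum_exp)                             # "Stats" == row-wise logsumexp

torch.testing.assert_close(stats, torch.logsumexp(logits, dim=-1, keepdim=True))

# The backward pass can rebuild the softmax probabilities from Stats alone:
probs = torch.exp(logits - stats)
torch.testing.assert_close(probs, torch.softmax(logits, dim=-1))
```

This is also why the Python-side reconstruction `stats = Max + log(Sum_Exp)` in the second file below can be dropped once Stats is returned directly.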
@@ -379,13 +372,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
     O->set_ragged_offset(offset_o);
   }
 
-  if (!return_max_logit) {
-    Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
-    if (is_ragged_q && cudnn_runtime_version >= 90600) {
-      Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
-    } else {
-      Stats->set_stride({h * s_q, s_q, 1, 1});
-    }
+  Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
+  if (is_ragged_q && cudnn_runtime_version >= 90600) {
+    Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
+  } else {
+    Stats->set_stride({h * s_q, s_q, 1, 1});
   }
 
   std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>,  // Q
@@ -395,7 +386,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
              std::shared_ptr<fe::graph::Tensor_attributes>>  // O
       key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
   auto Stats_tuple =
-      generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp);
+      return_max_logit ? std::make_tuple(Stats, Max) : std::make_tuple(Stats, nullptr);
   auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
   auto softmax_offset_tuple =
       is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
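Stats now always occupies the first slot of Stats_tuple, and Max rides along in the second slot only when return_max_logit is set; the caller then maps the first slot to devPtrS1 and the second to devPtrS2 (see the Aux_CTX_Tensors hunks below). A hypothetical Python illustration of that selection (not TE API; names made up):

```python
def pack_softmax_aux(stats, max_logits, return_max_logit):
    """Stats is always produced; Max is attached only when requested."""
    return (stats, max_logits) if return_max_logit else (stats, None)
```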
@@ -1125,6 +1116,16 @@ void fused_attn_arbitrary_seqlen_fwd(
   size_t i = 0;
   if (Aux_CTX_Tensors->size == 0) {
     const auto cudnn_runtime_version = cudnnGetVersion();
Collaborator:
You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in the _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!

Author:
Looks like I don't need to because …
+
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    output_S->data.dptr = nullptr;
+    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+      output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
+    } else {
+      output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
+    }
+    output_S->data.dtype = DType::kFloat32;
+
     if (return_max_logit) {
       Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
       output_Max->data.dptr = nullptr;
@@ -1134,23 +1135,6 @@ void fused_attn_arbitrary_seqlen_fwd(
         output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
       }
       output_Max->data.dtype = DType::kFloat32;
-      Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      output_Sum_Exp->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-        output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
-      } else {
-        output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-      }
-      output_Sum_Exp->data.dtype = DType::kFloat32;
-    } else {
-      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      output_S->data.dptr = nullptr;
-      if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
-        output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
-      } else {
-        output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
-      }
-      output_S->data.dtype = DType::kFloat32;
     }
 
     Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
@@ -1174,14 +1158,12 @@ void fused_attn_arbitrary_seqlen_fwd(
     Aux_CTX_Tensors->size = i;
   } else if (Aux_CTX_Tensors->size >= 2) {
+    Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+    devPtrS1 = output_S->data.dptr;
+
     if (return_max_logit) {
       Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      devPtrS1 = output_Max->data.dptr;
-      Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      devPtrS2 = output_Sum_Exp->data.dptr;
-    } else {
-      Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
-      devPtrS1 = output_S->data.dptr;
+      devPtrS2 = output_Max->data.dptr;
     }
     Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
     output_rng_state->data.dptr = rng_state->data.dptr;
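For reference, the shapes registered for the auxiliary outputs in the hunks above (Stats always, Max only when return_max_logit) depend on the query format and cuDNN version. A small sketch of that shape selection, using a hypothetical helper name:

```python
def softmax_aux_shape(q_format, cudnn_version, batch, num_attn_heads, max_seqlen_q, num_tokens_q):
    """Shape used for both Stats and Max in the forward auxiliary outputs."""
    if q_format == "thd" and cudnn_version >= 90600:
        # Ragged (THD) layout: one stats row per token.
        return (num_tokens_q, num_attn_heads, 1)
    # Dense layouts (BSHD/SBHD): per batch, head, and query position.
    return (batch, num_attn_heads, max_seqlen_q, 1)


# Example: BSHD with b=2, h=16, sq=512 -> (2, 16, 512, 1)
print(softmax_aux_shape("bshd", 91002, 2, 16, 512, num_tokens_q=1024))
```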
@@ -339,14 +339,14 @@ def fused_attn_fwd(
     if return_max_logit:
         qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
-        # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
-        # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
-        # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
-        stats = output_tensors[1] + torch.log(output_tensors[2])
+        # thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
+        # bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
+        # sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Collaborator:
Do we need the "there's no typo here"? :)

Author:
I deliberately added it because I didn't believe it and checked the shapes myself :P
+        aux_ctx_tensors = [output_tensors[1]]  # "Stats"
         amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
         # Max -> max_logit [h]
-        max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
-        aux_ctx_tensors = [stats]
+        max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
Collaborator:
Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, but Stats is always returned while Max can be toggled), rather than calling amax() in TE? (Sudhakar: Why am I able to update your comment?)

Author:
cuDNN returns Max ([b, h, sq, 1]), so it's an additional softmax statistic (apparently, the subset …). Further, for muon, we need to do amax on it to get a dimension-[h] tensor. Perf-wise, it'd be nice for cuDNN to do an additional reduction to return the [h]-shaped tensor for muon as well, but that's not in the scope of this PR. (Kshitij: looks like I can as well)
         aux_ctx_tensors.extend(output_tensors[3:])
         return output_tensors[0], aux_ctx_tensors, max_logit
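To make the Python-side change concrete: output_tensors[1] (Stats) now goes straight into aux_ctx_tensors, and max_logit is a per-head reduction of output_tensors[2] (Max). A standalone sketch with dummy tensors, shapes as in the BSHD comments above:

```python
import torch

b, h, sq, d = 2, 16, 128, 64
out = torch.randn(b, sq, h, d, dtype=torch.bfloat16)  # attention output (dummy)
stats = torch.randn(b, h, sq, 1)                      # Stats from cuDNN (logsumexp)
max_ = torch.randn(b, h, sq, 1)                       # Max from cuDNN

aux_ctx_tensors = [stats]  # saved for the backward pass
# Reduce Max over batch and sequence dims -> one value per head, shape [h]
max_logit = torch.amax(max_, dim=(0, 2, 3)).to(dtype=out.dtype)
print(max_logit.shape)  # torch.Size([16])
# For the THD layout Max is [tq, h, 1], so the reduction dims are (0, 2) instead.
```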