@@ -104,7 +104,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
}

const DType ragged_offset_type = cudnn_runtime_version >= 90500 ? DType::kInt64 : DType::kInt32;
- bool generate_stats = !return_max_logit;
+ bool generate_stats = true; // Always return stats
try {
FADescriptor_v1 descriptor{
b,
@@ -335,7 +335,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
sdpa_options.set_sink_token(softmax_offset);
}

- std::shared_ptr<fe::graph::Tensor_attributes> Max, Sum_Exp;
+ std::shared_ptr<fe::graph::Tensor_attributes> Max;
if (is_ragged_q && cudnn_runtime_version >= 90600) {
offset_stats =
mha_graph->tensor(fe::graph::Tensor_attributes()
@@ -349,19 +349,12 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
.set_name("Max")
.set_dim({b, h, s_q, 1})
.set_data_type(fe::DataType_t::FLOAT));
- Sum_Exp = mha_graph->tensor(fe::graph::Tensor_attributes()
- .set_name("Sum_Exp")
- .set_dim({b, h, s_q, 1})
- .set_data_type(fe::DataType_t::FLOAT));
if (is_ragged_q && cudnn_runtime_version >= 90600) {
Max->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
- Sum_Exp->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
} else {
Max->set_stride({h * s_q, s_q, 1, 1});
- Sum_Exp->set_stride({h * s_q, s_q, 1, 1});
}
sdpa_options.set_logit_max(Max);
- sdpa_options.set_score_sum_exp(Sum_Exp);
}

auto [O, Stats] = mha_graph->sdpa(Q, K, V, std::move(sdpa_options));
@@ -379,13 +372,11 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
O->set_ragged_offset(offset_o);
}

- if (!return_max_logit) {
- Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
- if (is_ragged_q && cudnn_runtime_version >= 90600) {
- Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
- } else {
- Stats->set_stride({h * s_q, s_q, 1, 1});
- }
+ Stats->set_output(true).set_data_type(fe::DataType_t::FLOAT).set_dim({b, h, s_q, 1});
+ if (is_ragged_q && cudnn_runtime_version >= 90600) {
+ Stats->set_stride({h * s_q, 1, h, 1}).set_ragged_offset(offset_stats);
+ } else {
+ Stats->set_stride({h * s_q, s_q, 1, 1});
+ }

std::tuple<std::shared_ptr<fe::graph::Tensor_attributes>, // Q
@@ -395,7 +386,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
std::shared_ptr<fe::graph::Tensor_attributes>> // O
key_tensors_tuple = std::make_tuple(Q, K, V, attn_scale, O);
auto Stats_tuple =
- generate_stats ? std::make_tuple(Stats, nullptr) : std::make_tuple(Max, Sum_Exp);
+ return_max_logit ? std::make_tuple(Stats, Max) : std::make_tuple(Stats, nullptr);
auto bias_tuple = is_bias ? std::make_tuple(bias) : std::make_tuple(nullptr);
auto softmax_offset_tuple =
is_softmax_offset ? std::make_tuple(softmax_offset) : std::make_tuple(nullptr);
@@ -1125,6 +1116,16 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t i = 0;
if (Aux_CTX_Tensors->size == 0) {
const auto cudnn_runtime_version = cudnnGetVersion();

Review comment from @cyanguwa (Collaborator), Feb 18, 2026:
You might need to make these changes in the "Aux_CTX_Tensors->size == 0" sections in the _fwd/bwd_qkvpacked/kvpacked APIs as well. Please check. Thanks!

Reply from the PR author (Collaborator):
Looks like I don't need to, because nvte_fused...qvpacked are in fused_attn.cpp, which calls fused_attn_f16_arbitrary... just like the regular nvte_fused_fwd/bwd.

+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ output_S->data.dptr = nullptr;
+ if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+ output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
+ } else {
+ output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
+ }
+ output_S->data.dtype = DType::kFloat32;

if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_Max->data.dptr = nullptr;
@@ -1134,23 +1135,6 @@ void fused_attn_arbitrary_seqlen_fwd(
output_Max->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
}
output_Max->data.dtype = DType::kFloat32;
- Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
- output_Sum_Exp->data.dptr = nullptr;
- if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_Sum_Exp->data.shape = {num_tokens_q, num_attn_heads, 1};
- } else {
- output_Sum_Exp->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
- }
- output_Sum_Exp->data.dtype = DType::kFloat32;
- } else {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
- output_S->data.dptr = nullptr;
- if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
- output_S->data.shape = {num_tokens_q, num_attn_heads, 1};
- } else {
- output_S->data.shape = {batch, num_attn_heads, max_seqlen_q, 1};
- }
- output_S->data.dtype = DType::kFloat32;
}

Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
@@ -1174,14 +1158,12 @@ void fused_attn_arbitrary_seqlen_fwd(

Aux_CTX_Tensors->size = i;
} else if (Aux_CTX_Tensors->size >= 2) {
+ Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
+ devPtrS1 = output_S->data.dptr;

if (return_max_logit) {
Tensor *output_Max = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
- devPtrS1 = output_Max->data.dptr;
- Tensor *output_Sum_Exp = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
- devPtrS2 = output_Sum_Exp->data.dptr;
- } else {
- Tensor *output_S = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
- devPtrS1 = output_S->data.dptr;
+ devPtrS2 = output_Max->data.dptr;
}
Tensor *output_rng_state = convertNVTETensorCheck(Aux_CTX_Tensors->tensors[i++]);
output_rng_state->data.dptr = rng_state->data.dptr;
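
A quick illustration of the Stats/Max shape logic above: in the THD (ragged) path with cuDNN >= 9.6.0 the softmax stats are packed over the actual token count, while the BSHD/SBHD path pads them to max_seqlen_q. The sketch below uses hypothetical toy sizes and a made-up cu_seqlens_q, neither of which comes from this PR:

```python
import torch

# Hypothetical toy sizes, only to illustrate the two shape conventions.
batch, num_attn_heads, max_seqlen_q = 2, 4, 8
cu_seqlens_q = torch.tensor([0, 5, 8], dtype=torch.int32)  # made-up ragged offsets
num_tokens_q = int(cu_seqlens_q[-1])  # total tokens actually present across the batch

# THD (ragged) path, cuDNN >= 9.6.0: stats packed over the real tokens.
stats_thd = torch.empty(num_tokens_q, num_attn_heads, 1, dtype=torch.float32)

# BSHD/SBHD path: stats padded per sequence up to max_seqlen_q.
stats_bshd = torch.empty(batch, num_attn_heads, max_seqlen_q, 1, dtype=torch.float32)

print(stats_thd.shape)   # torch.Size([8, 4, 1])
print(stats_bshd.shape)  # torch.Size([2, 4, 8, 1])
```

The same two layouts apply to Max when return_max_logit is set, which is why the code above configures Stats and Max with the same strides and ragged offset.
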
6 changes: 3 additions & 3 deletions transformer_engine/common/fused_attn/utils.h
@@ -118,23 +118,23 @@ struct FADescriptor_v1 {
cudnn_frontend::DataType_t o_tensor_type;
cudnn_frontend::DataType_t do_tensor_type;
cudnn_frontend::DataType_t dqkv_tensor_type;
- bool generate_max_sum_exp;
+ bool return_max_logit;

bool operator<(const FADescriptor_v1 &rhs) const {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, bias_sq,
bias_skv, attnScale, isTraining, dropoutProbability, layout, mask_type,
softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
deterministic, bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type,
- dqkv_tensor_type, generate_max_sum_exp) <
+ dqkv_tensor_type, return_max_logit) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.bias_sq, rhs.bias_skv,
rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
rhs.mask_type, rhs.softmax_type, rhs.window_size_left, rhs.window_size_right,
rhs.bottom_right_diagonal, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
- rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
+ rhs.dqkv_tensor_type, rhs.return_max_logit);
}
};

@@ -206,7 +206,7 @@ NVTE_QKV_Format nvte_get_kv_format(NVTE_QKV_Layout qkv_layout);
* \param[in] head_dim_v The head dimension of V.
* \param[in] window_size_left Sliding window size (the left half).
* \param[in] window_size_right Sliding window size (the right half).
- * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] return_max_logit Whether to produce Max along with Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] deterministic Whether determinism is required or not.
*/
@@ -260,7 +260,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
* \param[in] max_seqlen Max sequence length used for computing,
* it may be >= max(seqlen_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
- * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] return_max_logit Whether to produce Max along with Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
@@ -400,7 +400,7 @@ void nvte_fused_attn_bwd_qkvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for KV.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
- * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] return_max_logit Whether to produce Max along with Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
@@ -553,7 +553,7 @@ void nvte_fused_attn_bwd_kvpacked(
* \param[in] max_seqlen_kv Max sequence length used for computing for K and V.
* it may be >= max(seqlen_kv_i) for i=0,...batch_size-1.
* \param[in] is_training Whether this is in training mode or inference.
- * \param[in] return_max_logit Whether to produce Max and Sum_Exp, or Stats.
+ * \param[in] return_max_logit Whether to produce Max along with Stats.
* \param[in] cuda_graph Whether cuda graph capture is enabled or not.
* \param[in] attn_scale Scaling factor for Q * K.T.
* \param[in] dropout Dropout probability.
12 changes: 6 additions & 6 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
@@ -339,14 +339,14 @@ def fused_attn_fwd(

if return_max_logit:
qkv_format = qkv_layout.replace("3", "").replace("2", "").split("_")[0]
- # thd: output_tensors: out [tq, h, d], Max [tq, h, 1], Sum_Exp [tq, h, 1]
- # bshd: output_tensors: out [b, sq, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
- # sbhd: output_tensors: out [sq, b, h, d], Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1]
- stats = output_tensors[1] + torch.log(output_tensors[2])
+ # thd: output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
+ # bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
+ # sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Review comment from a Collaborator:
Do we need the "there's no typo here" :)

Reply from the PR author (Collaborator):
I deliberately added it because I didn't believe it and checked the shapes myself :P

+ aux_ctx_tensors = [output_tensors[1]] # "Stats"
amax_dims = (0, 2) if qkv_format == "thd" else (0, 2, 3)
# Max -> max_logit [h]
- max_logit = torch.amax(output_tensors[1], dim=amax_dims).to(dtype=output_tensors[0].dtype)
- aux_ctx_tensors = [stats]
+ max_logit = torch.amax(output_tensors[2], dim=amax_dims).to(dtype=output_tensors[0].dtype)
Review comment from @KshitijLakhani (Collaborator), Feb 19, 2026:
Maybe I understood this incorrectly, but isn't TE now also supposed to receive Max from cuDNN directly (like Stats, except that Stats is always returned while Max can be toggled), rather than calling amax() in TE?
(Sudhakar: Why am I able to update your comment?)

Reply from @sudhakarsingh27 (Collaborator, PR author), Feb 20, 2026:
cuDNN returns Max ([b, h, sq, 1]), so it's an additional softmax statistic (apparently, the subset (Stats, Max) is enough for the cuDNN bwd rather than the full set (Stats, SumExp, Max)).
Further, for muon, we need to do amax on it to get a tensor with shape [h]. return_max_logit in TE controls whether to fetch Max from cuDNN.
Perf-wise, it'd be nice for cuDNN to do an additional reduction and return the [h]-shaped tensor for muon as well, but that's outside the scope of this PR.
(Kshitij: looks like I can as well)

aux_ctx_tensors.extend(output_tensors[3:])
return output_tensors[0], aux_ctx_tensors, max_logit

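
To make the amax_dims choice above concrete: Max comes back with one value per token (or padded position) and head, and the reduction collapses everything except the head dimension. A minimal sketch with toy sizes (not taken from the PR), mirroring the thd vs bshd/sbhd branches:

```python
import torch

num_heads = 4

# thd layout: Max is [tq, h, 1]; reduce over tokens and the trailing singleton dim.
tq = 10
max_thd = torch.randn(tq, num_heads, 1)
max_logit_thd = torch.amax(max_thd, dim=(0, 2))  # -> [h]

# bshd/sbhd layouts: Max is [b, h, sq, 1]; reduce over batch, sequence, and the singleton dim.
b, sq = 2, 8
max_bshd = torch.randn(b, num_heads, sq, 1)
max_logit_bshd = torch.amax(max_bshd, dim=(0, 2, 3))  # -> [h]

assert max_logit_thd.shape == (num_heads,)
assert max_logit_bshd.shape == (num_heads,)
```

In the actual code the result is additionally cast to the output dtype via .to(dtype=output_tensors[0].dtype).
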
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/csrc/extensions/attention.cpp
@@ -259,16 +259,16 @@ std::vector<py::object> fused_attn_fwd(
// f16_max512 : S [b, h, sq, skv]
// f16_arbitrary:
// return_max_logit=false: S [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
- // return_max_logit=true: Max [b, h, sq, 1], Sum_Exp [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
+ // return_max_logit=true: S [b, h, sq, 1], Max [b, h, sq, 1], rng_state [2], (optional) Bias [1, h, sq, skv], (optional) SoftmaxOffset [1, h, 1, 1]
// fp8 : M [b, h, sq, 1], ZInv [b, h, sq, 1], rng_state [2]
size_t i = 0;
at::Tensor output_tensor;
- // intermediate softmax tensor, S or M
+ // intermediate softmax tensor, S or M (for fp8)
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
static_cast<DType>(nvte_tensor_type(nvte_aux_tensor_pack.tensors[i])), false);
set_tensor_param(i++, output_tensor);
- // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Sum_Exp tensor
+ // fp8 has an additional softmax stats tensor, ZInv; return_max_logit=true has an additional Max tensor
if (return_max_logit || qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) {
output_tensor =
allocateSpace(nvte_shape_to_vector(nvte_tensor_shape(nvte_aux_tensor_pack.tensors[i])),
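
As a compact restatement of the aux-tensor layouts listed in the comment block above, here is a hypothetical helper; the function name and backend labels are invented for illustration, and only the tensor orderings come from the comments:

```python
def expected_aux_tensors(backend, return_max_logit=False, has_bias=False, has_softmax_offset=False):
    """Hypothetical helper that restates the aux-tensor comment block above."""
    if backend == "f16_max512":
        return ["S"]  # S [b, h, sq, skv]
    if backend == "fp8":
        return ["M", "ZInv", "rng_state"]
    # f16_arbitrary
    names = ["S"]
    if return_max_logit:
        names.append("Max")  # return_max_logit=true adds Max right after S
    names.append("rng_state")
    if has_bias:
        names.append("Bias")
    if has_softmax_offset:
        names.append("SoftmaxOffset")
    return names


print(expected_aux_tensors("f16_arbitrary", return_max_logit=True))
# ['S', 'Max', 'rng_state']
```
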