From 4aab56c7ca68606b7c55487b694cb874b817582d Mon Sep 17 00:00:00 2001
From: Fabian Joswig
Date: Fri, 20 Feb 2026 20:59:36 +0100
Subject: [PATCH] [PyTorch] Zero-initialize learnable softmax_offset in
 DotProductAttention

DotProductAttention used torch.empty() for the learnable softmax_offset
parameter. Unlike all other TransformerEngineBaseModule subclasses,
DotProductAttention does not call reset_parameters() in __init__, so the
deferred initialization that would normally overwrite the empty tensor is
never invoked, leaving the parameter with uninitialized memory.

The JAX implementation explicitly uses nn.initializers.zeros for this
parameter. This aligns the PyTorch behavior by using torch.zeros().

Signed-off-by: Fabian Joswig
---
 .../attention/dot_product_attention/dot_product_attention.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 64db4646f6..2dc42be18a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -439,7 +439,7 @@ def __init__(
         if self.softmax_type == "learnable":
             self.register_parameter(
                 "softmax_offset",
-                Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
+                Parameter(torch.zeros(self.num_attention_heads // self.tp_size, device="cuda")),
                 get_rng_state_tracker=get_rng_state_tracker,
             )
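
Note (not part of the patch): a minimal standalone sketch of why torch.empty()
is unsafe when no later reset_parameters() call overwrites the tensor. The
num_heads value is illustrative, and the sketch runs on CPU for simplicity.

    import torch

    num_heads = 8  # illustrative value, not taken from the patch

    # torch.empty() only allocates memory; the tensor holds whatever bytes
    # happen to be in that allocation, so the learnable offset would start
    # from arbitrary garbage if nothing later overwrites it.
    uninitialized = torch.empty(num_heads)
    print(uninitialized)  # arbitrary, non-deterministic values

    # torch.zeros() gives a deterministic zero starting point, matching the
    # JAX initializer (nn.initializers.zeros) cited in the commit message.
    initialized = torch.zeros(num_heads)
    print(initialized)  # tensor([0., 0., 0., 0., 0., 0., 0., 0.])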