diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
index 64db4646f6..2dc42be18a 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
@@ -439,7 +439,7 @@ def __init__(
         if self.softmax_type == "learnable":
             self.register_parameter(
                 "softmax_offset",
-                Parameter(torch.empty(self.num_attention_heads // self.tp_size, device="cuda")),
+                Parameter(torch.zeros(self.num_attention_heads // self.tp_size, device="cuda")),
                 get_rng_state_tracker=get_rng_state_tracker,
             )
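
For context on the change, below is a minimal sketch (illustrative shapes, not TE's code) contrasting `torch.empty` with `torch.zeros` for a per-head offset parameter: `empty` leaves the tensor uninitialized and run-dependent, while `zeros` gives the learnable `softmax_offset` a deterministic, neutral starting point. The softmax expression at the end is an assumption about how a learnable offset/sink term could enter the denominator, not a statement of the library's exact formula.

```python
import torch

# Hypothetical sizes for illustration; not taken from the library.
num_attention_heads, tp_size = 16, 2
per_rank_heads = num_attention_heads // tp_size

# torch.empty() returns uninitialized memory: the offset would start from
# arbitrary, run-dependent values unless something overwrites it later.
offset_empty = torch.empty(per_rank_heads)

# torch.zeros() gives a deterministic starting point for the learnable offset.
offset_zero = torch.zeros(per_rank_heads)

# Assumed interpretation: the offset acts as an extra "sink" logit per head,
# so a zero offset contributes exp(0) = 1 to the softmax denominator.
logits = torch.randn(per_rank_heads, 8)  # toy per-head attention logits
denom = logits.exp().sum(dim=-1, keepdim=True) + offset_zero.exp().unsqueeze(-1)
attn = logits.exp() / denom  # rows sum to slightly less than 1 (sink mass)
```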