[PyTorch] Zero-initialize learnable softmax_offset in DotProductAttention #2694

Open
fjosw wants to merge 1 commit into NVIDIA:main from fjosw:fix/softmax-offset-zero-init-v2

Conversation

fjosw (Contributor) commented Feb 20, 2026

Description

The PyTorch implementation of DotProductAttention initializes the learnable softmax_offset parameter with torch.empty(), which leaves it containing uninitialized memory. Unlike all other TransformerEngineBaseModule subclasses (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear), DotProductAttention does not call self.reset_parameters() in its __init__, so the deferred initialization system that would normally overwrite the torch.empty() contents is never invoked.

The JAX implementation explicitly uses nn.initializers.zeros for this parameter. This fix aligns the PyTorch behavior by using torch.zeros() directly. In Megatron-LM this is not a problem because the parameter is initialized explicitly, but when DotProductAttention is used in isolation the uninitialized values can cause problems.
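As a rough sketch of the change (illustrative only; the parameter shape and the surrounding code are not taken from dot_product_attention.py):

```python
import torch

num_attention_heads = 16  # illustrative value, not the module's actual configuration

# Before: torch.empty() only allocates storage, so the parameter starts with
# whatever bytes happen to be in that memory.
softmax_offset = torch.nn.Parameter(torch.empty(1, num_attention_heads, 1, 1))

# After this PR: torch.zeros() gives a deterministic zero start, matching the
# JAX implementation's nn.initializers.zeros.
softmax_offset = torch.nn.Parameter(torch.zeros(1, num_attention_heads, 1, 1))
```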

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Change torch.empty() to torch.zeros() when creating the learnable softmax_offset parameter in DotProductAttention, ensuring it is zero-initialized rather than containing uninitialized memory.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

[PyTorch] Zero-initialize learnable softmax_offset in DotProductAttention

DotProductAttention used torch.empty() for the learnable softmax_offset
parameter. Unlike all other TransformerEngineBaseModule subclasses,
DotProductAttention does not call reset_parameters() in __init__, so the
deferred initialization that would normally overwrite the empty tensor is
never invoked, leaving the parameter with uninitialized memory.

The JAX implementation explicitly uses nn.initializers.zeros for this
parameter. This aligns the PyTorch behavior by using torch.zeros().

Signed-off-by: Fabian Joswig <fjosw@users.noreply.github.com>
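
For context, the deferred-initialization pattern the commit message refers to looks roughly like this; this is a schematic sketch, not the actual TransformerEngineBaseModule code:

```python
import torch

class DeferredInitModule(torch.nn.Module):
    """Schematic of the pattern other TE modules follow; illustrative only."""

    def __init__(self, features: int):
        super().__init__()
        # torch.empty() allocates storage without initializing its contents.
        self.weight = torch.nn.Parameter(torch.empty(features, features))
        # Other TransformerEngineBaseModule subclasses call this in __init__;
        # DotProductAttention did not, so the torch.empty() contents survived.
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Deferred initialization overwrites the uninitialized storage.
        torch.nn.init.zeros_(self.weight)
```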
greptile-apps bot (Contributor) commented Feb 20, 2026

Greptile Summary

This PR fixes a parameter initialization bug in DotProductAttention where the learnable softmax_offset parameter was initialized with torch.empty(), leaving it with uninitialized memory values. The fix changes this to torch.zeros() to ensure proper zero-initialization.

Key Points:

  • The JAX implementation explicitly uses nn.initializers.zeros for this parameter (line 165 in transformer_engine/jax/flax/transformer.py)
  • The "off-by-one" softmax type already uses torch.zeros() in the same file (line 436-438)
  • Unlike other TransformerEngineBaseModule subclasses, DotProductAttention does not call self.reset_parameters() in __init__, so the deferred initialization system never overwrites the uninitialized values
  • This bug would only manifest when using DotProductAttention in isolation, without Megatron-LM's explicit initialization (see the sketch after this list)
  • Existing tests cover the learnable softmax type and should verify the fix works correctly

Confidence Score: 5/5

  • This PR is safe to merge with minimal risk
  • Single-line bug fix that corrects improper initialization, aligns with JAX implementation and the "off-by-one" case in the same file, has existing test coverage, and follows established patterns
  • No files require special attention

Important Files Changed

Filename: transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py
Overview: Changed torch.empty() to torch.zeros() for the learnable softmax_offset parameter initialization, ensuring zero-initialization instead of uninitialized memory

Last reviewed commit: 4aab56c

greptile-apps bot left a comment

1 file reviewed, no comments
