[PyTorch] Zero-initialize learnable softmax_offset in DotProductAttention#2694
Open
fjosw wants to merge 1 commit intoNVIDIA:mainfrom
Open
[PyTorch] Zero-initialize learnable softmax_offset in DotProductAttention#2694fjosw wants to merge 1 commit intoNVIDIA:mainfrom
fjosw wants to merge 1 commit intoNVIDIA:mainfrom
Conversation
…tion DotProductAttention used torch.empty() for the learnable softmax_offset parameter. Unlike all other TransformerEngineBaseModule subclasses, DotProductAttention does not call reset_parameters() in __init__, so the deferred initialization that would normally overwrite the empty tensor is never invoked, leaving the parameter with uninitialized memory. The JAX implementation explicitly uses nn.initializers.zeros for this parameter. This aligns the PyTorch behavior by using torch.zeros(). Signed-off-by: Fabian Joswig <fjosw@users.noreply.github.com>
Contributor
Greptile SummaryThis PR fixes a parameter initialization bug in Key Points:
Confidence Score: 5/5
Important Files Changed
Last reviewed commit: 4aab56c |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Description
The PyTorch implementation of DotProductAttention initializes the learnable
softmax_offsetparameter withtorch.empty(), which leaves it containing uninitialized memory. Unlike all otherTransformerEngineBaseModulesubclasses (Linear, LayerNormLinear, LayerNormMLP, GroupedLinear), DotProductAttention does not callself.reset_parameters() in its__init__, so the deferred initialization system that would normally overwrite the torch.empty() contents is never invoked. The JAX implementation explicitly uses nn.initializers.zeros for this parameter. This fix aligns the PyTorch behavior by using torch.zeros() directly. In Megatron-LM this is not a problem because the paramter is initialised explicitly but when used in isolation this can lead to problems.Fixes # (issue)
Type of change
Changes
Checklist: