[PyTorch] Error out if constructing LayerNormLinear with row tensor parallelism#2688
timmoon10 wants to merge 3 commits into NVIDIA:main

Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile Summary

Added validation to prevent construction of LayerNormLinear with row tensor parallelism.

Key changes:
Issues found:
Confidence Score: 3/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[LayerNormLinear.__init__] --> B{parallel_mode in GemmParallelModes?}
    B -->|No| C[AssertionError: parallel_mode not supported]
    B -->|Yes| D{parallel_mode == 'row'?}
    D -->|Yes| E[NotImplementedError: Normalization does not support TP distribution]
    D -->|No| F{parallel_mode == 'column'?}
    F -->|Yes| G[Divide out_features by tp_size]
    F -->|No| H[No TP distribution]
    G --> I[Continue initialization]
    H --> I
    style E fill:#f66,stroke:#f00,color:#fff
    style C fill:#f66,stroke:#f00,color:#fff
    style I fill:#6f6,stroke:#0f0
```
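In code form, the check that the flowchart summarizes might look roughly like the following standalone sketch. `validate_parallel_mode` is a hypothetical helper written only for illustration; `GemmParallelModes`, `out_features`, `tp_size`, and the error types mirror the flowchart, while the exact message wording is an assumption.

```python
# Hypothetical standalone sketch of the parallel_mode validation added to
# LayerNormLinear.__init__; names mirror the flowchart above.
GemmParallelModes = ("row", "column", None)

def validate_parallel_mode(parallel_mode, out_features, tp_size):
    assert parallel_mode in GemmParallelModes, f"parallel_mode {parallel_mode} not supported"
    if parallel_mode == "row":
        # LayerNorm/RMSNorm statistics cannot be computed over an input sharded
        # along its inner dimension, so reject row TP up front instead of hitting
        # an illegal memory access inside the norm kernel.
        raise NotImplementedError("Normalization does not support tensor-parallel distribution")
    if parallel_mode == "column":
        # Column TP shards the output features across the tensor-parallel group.
        assert out_features % tp_size == 0, "out_features must be divisible by tp_size"
        out_features = out_features // tp_size
    return out_features
```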
Last reviewed commit: 1ef2ae0
Additional Comments (1)
This comment was marked as outdated.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
/te-ci pytorch
Additional Comments (3)
Description
LayerNormLinear modules with row tensor parallelism have input tensors that are sharded along the inner dimension: TransformerEngine/transformer_engine/pytorch/module/layernorm_linear.py
Lines 1199 to 1200 in 7e48fa1
However, we currently don't support tensor-parallel LayerNorm or RMSNorm, which would involve a tensor-parallel all-reduce to compute statistics. If the user attempts to run
LayerNormLinear with row tensor parallelism, then they experience an illegal memory access when the norm kernel accesses values in the unsharded norm weight tensor. We haven't experienced problems so far because row TP is usually used for the proj and fc2 layers, which are usually Linears.

This PR adds an error message to make the failure more obvious.
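For illustration, a minimal usage sketch of the new behavior follows. This is a sketch rather than code from the PR: the constructor arguments follow the public transformer_engine.pytorch API, the raised error type follows the validation described above, and the feature sizes and process-group setup are placeholders.

```python
import torch.distributed as dist
import transformer_engine.pytorch as te

# Placeholder tensor-parallel process-group setup (assumes init_process_group was called).
tp_group = dist.new_group()
tp_size = dist.get_world_size(group=tp_group)

# Before this PR: risked an illegal memory access inside the norm kernel.
# After this PR: raises NotImplementedError at construction time.
layer = te.LayerNormLinear(
    4096,               # in_features (placeholder size)
    4096,               # out_features (placeholder size)
    tp_group=tp_group,
    tp_size=tp_size,
    parallel_mode="row",
)

# Row TP is typically used for the proj and fc2 layers, which are usually plain
# Linear modules (no fused normalization), so those are unaffected:
proj = te.Linear(4096, 4096, tp_group=tp_group, tp_size=tp_size, parallel_mode="row")
```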
Type of change
Changes
Error out if constructing LayerNormLinear with row tensor parallelism

Checklist: