
[PyTorch] Error out if constructing LayerNormLinear with row tensor parallelism #2688

Open
timmoon10 wants to merge 3 commits into NVIDIA:main from timmoon10:tmoon/row-tp-layernorm-linear

Conversation

@timmoon10
Collaborator

Description

LayerNormLinear modules with row tensor parallelism have input tensors that are sharded along the inner dimension:

elif self.parallel_mode == "row":
    # Input features are sharded across the tensor-parallel group
    self.in_features = divide(self.in_features, self.tp_size)

However, we currently don't support tensor-parallel LayerNorm or RMSNorm, which would require a tensor-parallel all-reduce to compute the normalization statistics. If a user constructs LayerNormLinear with row tensor parallelism, they hit an illegal memory access when the norm kernel reads values from the unsharded norm weight tensor. We haven't run into this so far because row TP is typically used for the proj and fc2 layers, which are usually plain Linear modules.
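
For context, here is a minimal sketch (not part of this PR; the function name and signature are hypothetical) of what a row-parallel RMSNorm would have to do: compute local statistics on the shard, then all-reduce them across the tensor-parallel group so they cover the full hidden dimension.

    import torch
    import torch.distributed as dist

    def row_parallel_rmsnorm(x_local, weight_local, tp_group, hidden_size, eps=1e-5):
        # Each rank holds only a shard of the hidden dim: [..., hidden_size // tp_size],
        # and weight_local is the matching shard of the norm weight.
        sq_sum = x_local.pow(2).sum(dim=-1, keepdim=True)
        # Statistics must cover the full hidden dim, hence the extra all-reduce.
        dist.all_reduce(sq_sum, op=dist.ReduceOp.SUM, group=tp_group)
        inv_rms = torch.rsqrt(sq_sum / hidden_size + eps)
        return x_local * inv_rms * weight_local

Neither the norm kernels nor the module plumbing do this today, and the norm weight is kept unsharded, which is how the out-of-bounds access arises.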

This PR adds an error message to make the failure more obvious.
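
Concretely, the guard is along these lines (a rough sketch; the exact message and its placement inside LayerNormLinear.__init__ may differ from the actual diff):

    if self.parallel_mode == "row":
        raise NotImplementedError(
            "Normalization does not support tensor-parallel distribution, "
            "so LayerNormLinear cannot be constructed with row tensor parallelism."
        )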

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Error out if constructing LayerNormLinear with row tensor parallelism

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@timmoon10 added the bug (Something isn't working) label on Feb 17, 2026
@greptile-apps
Contributor

greptile-apps bot commented Feb 17, 2026

Greptile Summary

Added validation to prevent construction of LayerNormLinear with row tensor parallelism, which was causing illegal memory accesses due to unsupported tensor-parallel normalization operations.

Key changes:

  • Added NotImplementedError check in LayerNormLinear.__init__ when parallel_mode == "row"
  • Removed row-parallel test cases for LayerNormLinear from distributed comm-gemm overlap tests

Issues found:

  • Dead code remains in constructor after the validation check (lines 1203-1204, 1231-1237, 1388-1393)
  • The elif block at line 1203 and other row-parallel initialization logic will never execute
  • Consider cleaning up unreachable code paths for maintainability

Confidence Score: 3/5

  • Safe to merge but should clean up dead code paths for better maintainability
  • The core fix is correct and prevents the illegal memory access issue. Tests are appropriately updated. However, dead code remains after the validation check that should be removed to prevent future confusion
  • Pay attention to transformer_engine/pytorch/module/layernorm_linear.py - dead code should be cleaned up

Important Files Changed

  • transformer_engine/pytorch/module/layernorm_linear.py: Added validation to prevent row-parallel construction, but left dead code in constructor and initialization logic
  • tests/pytorch/distributed/test_comm_gemm_overlap.py: Correctly removed row-parallel test cases for LayerNormLinear

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[LayerNormLinear.__init__] --> B{parallel_mode in GemmParallelModes?}
    B -->|No| C[AssertionError: parallel_mode not supported]
    B -->|Yes| D{parallel_mode == 'row'?}
    D -->|Yes| E[NotImplementedError: Normalization does not support TP distribution]
    D -->|No| F{parallel_mode == 'column'?}
    F -->|Yes| G[Divide out_features by tp_size]
    F -->|No| H[No TP distribution]
    G --> I[Continue initialization]
    H --> I
    
    style E fill:#f66,stroke:#f00,color:#fff
    style C fill:#f66,stroke:#f00,color:#fff
    style I fill:#6f6,stroke:#0f0

Last reviewed commit: 1ef2ae0


@greptile-apps
Contributor

greptile-apps bot commented Feb 17, 2026

Additional Comments (1)

transformer_engine/pytorch/module/layernorm_linear.py
Unreachable code: this elif branch will never execute since row parallel mode now raises NotImplementedError on line 1197. Consider removing these lines or moving the error check after this conditional.
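
One way the suggested cleanup could look (an illustrative sketch only; the surrounding constructor code is not reproduced here) is to drop the unreachable row branch and keep only the column-parallel adjustment after the guard:

    if self.parallel_mode == "row":
        raise NotImplementedError(
            "Normalization does not support tensor-parallel distribution"
        )
    if self.parallel_mode == "column":
        # Only column parallelism shards the output features.
        self.out_features = divide(self.out_features, self.tp_size)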


ptrendx previously approved these changes Feb 18, 2026
@timmoon10
Collaborator Author

/te-ci pytorch L1

Contributor

greptile-apps bot left a comment


2 files reviewed, 3 comments


@greptile-apps
Contributor

greptile-apps bot commented Feb 20, 2026

Additional Comments (3)

transformer_engine/pytorch/module/layernorm_linear.py
Dead code - the NotImplementedError raised on line 1197 prevents execution from reaching this row-parallel case


transformer_engine/pytorch/module/layernorm_linear.py
Dead code - row-parallel mode raises NotImplementedError at line 1197, so these will always be False

        # Row-parallel overlaps (disabled - not supported)
        self.ub_overlap_rs_fprop = False
        self.ub_overlap_ag_dgrad = False

transformer_engine/pytorch/module/layernorm_linear.py
Dead code - row-parallel mode is not supported (line 1197), so this will always be False

        # For RPL, bias has to be added after TP collectives
        # So it cannot be fused with the GEMM
        # Note: RPL is not supported for LayerNormLinear
        self.gemm_bias_unfused_add = False


Labels

bug Something isn't working


3 participants