From 0f7674fbed6c68de59ab2be4b3e11eb43fb79790 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Tue, 17 Feb 2026 15:42:44 -0800 Subject: [PATCH 1/2] Error out if constructing LayerNormLinear with row tensor parallelism Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/layernorm_linear.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 702916696b..965d145823 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1193,6 +1193,10 @@ def __init__( assert ( self.parallel_mode in GemmParallelModes ), f"parallel_mode {parallel_mode} not supported" + if self.parallel_mode == "row": + raise NotImplementedError( + "Normalization does not support tensor-parallel distribution." + ) if self.parallel_mode == "column": self.out_features = divide(self.out_features, self.tp_size) From 0e225c17e8085529ffd539c15f7056c794e6f425 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Fri, 20 Feb 2026 22:46:31 +0000 Subject: [PATCH 2/2] Disable Userbuffers test for row-TP LayerNormLinear Signed-off-by: Tim Moon --- tests/pytorch/distributed/test_comm_gemm_overlap.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py index 95bf5aa05a..7a81f93bd6 100644 --- a/tests/pytorch/distributed/test_comm_gemm_overlap.py +++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py @@ -210,7 +210,6 @@ def test_bulk_overlaps(comm_type, quantization, connections): (te.Linear.__name__, "row", False), (te.Linear.__name__, "column", False), (te.Linear.__name__, "column", True), - (te.LayerNormLinear.__name__, "row", False), (te.LayerNormLinear.__name__, "column", False), (te.LayerNormLinear.__name__, "column", True), ] @@ -225,7 +224,6 @@ def test_bulk_overlaps(comm_type, quantization, connections): f" {te.Linear.__name__} - ROW-PARALLEL ", f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ", - f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ", f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ", f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ", ] @@ -254,7 +252,6 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d (te.Linear.__name__, "row", False), (te.Linear.__name__, "column", False), (te.Linear.__name__, "column", True), - (te.LayerNormLinear.__name__, "row", False), (te.LayerNormLinear.__name__, "column", False), (te.LayerNormLinear.__name__, "column", True), ] @@ -269,7 +266,6 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d f"{te.Linear.__name__}-row_tensor_parallel", f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD", f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS", - f"{te.LayerNormLinear.__name__}-row_tensor_parallel", f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD", f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS", ]