diff --git a/tests/pytorch/distributed/test_comm_gemm_overlap.py b/tests/pytorch/distributed/test_comm_gemm_overlap.py
index 95bf5aa05a..7a81f93bd6 100644
--- a/tests/pytorch/distributed/test_comm_gemm_overlap.py
+++ b/tests/pytorch/distributed/test_comm_gemm_overlap.py
@@ -210,7 +210,6 @@ def test_bulk_overlaps(comm_type, quantization, connections):
         (te.Linear.__name__, "row", False),
         (te.Linear.__name__, "column", False),
         (te.Linear.__name__, "column", True),
-        (te.LayerNormLinear.__name__, "row", False),
         (te.LayerNormLinear.__name__, "column", False),
         (te.LayerNormLinear.__name__, "column", True),
     ]
@@ -225,7 +224,6 @@ def test_bulk_overlaps(comm_type, quantization, connections):
         f" {te.Linear.__name__} - ROW-PARALLEL ",
         f" {te.Linear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
         f" {te.Linear.__name__} - COL-PARLALEL - DGRAD+RS ",
-        f" {te.LayerNormLinear.__name__} - ROW-PARALLEL ",
         f" {te.LayerNormLinear.__name__} - COL-PARALLEL - BULK DGRAD/WGRAD ",
         f" {te.LayerNormLinear.__name__} - COL-PARALLEL - DGRAD+RS ",
     ]
@@ -254,7 +252,6 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
         (te.Linear.__name__, "row", False),
         (te.Linear.__name__, "column", False),
         (te.Linear.__name__, "column", True),
-        (te.LayerNormLinear.__name__, "row", False),
         (te.LayerNormLinear.__name__, "column", False),
         (te.LayerNormLinear.__name__, "column", True),
     ]
@@ -269,7 +266,6 @@ def test_layers_with_overlap_bf16(layer_type, linear_parallel_mode, overlap_rs_d
         f"{te.Linear.__name__}-row_tensor_parallel",
         f"{te.Linear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
         f"{te.Linear.__name__}-col_tensor_parallel-DGRAD+RS",
-        f"{te.LayerNormLinear.__name__}-row_tensor_parallel",
         f"{te.LayerNormLinear.__name__}-col_tensor_parallel-BULK DGRAD/WGRAD",
         f"{te.LayerNormLinear.__name__}-col_tensor_parallel-DGRAD+RS",
     ]
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 702916696b..965d145823 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -1193,6 +1193,10 @@ def __init__(
         assert (
             self.parallel_mode in GemmParallelModes
         ), f"parallel_mode {parallel_mode} not supported"
+        if self.parallel_mode == "row":
+            raise NotImplementedError(
+                "Normalization does not support tensor-parallel distribution."
+            )
         if self.parallel_mode == "column":
             self.out_features = divide(self.out_features, self.tp_size)
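
A minimal sketch of what the new guard changes for callers, assuming a tensor-parallel job launched with torchrun across 2 GPUs; the script name, dimensions, and process-group setup below are illustrative and not part of this patch:

    # sketch.py -- run with: torchrun --nproc_per_node=2 sketch.py
    import torch
    import torch.distributed as dist
    import transformer_engine.pytorch as te

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(dist.get_rank())
    tp_group = dist.new_group()  # tensor-parallel group spanning all ranks

    # Row-parallel mode would shard the hidden dimension that layer
    # normalization reduces over, so construction now fails fast instead
    # of producing a layer with incorrect normalization statistics.
    try:
        layer = te.LayerNormLinear(
            1024,
            1024,
            parallel_mode="row",
            tp_group=tp_group,
            tp_size=dist.get_world_size(),
        )
    except NotImplementedError as err:
        print(err)  # Normalization does not support tensor-parallel distribution.

    dist.destroy_process_group()

Column-parallel mode is unaffected: it shards the output dimension, which the normalization (applied over the input features) never touches, which is why the tests above drop only the row-parallel LayerNormLinear cases.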