From f009c8910689453ab0154bfc3eae23584541f1ab Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Fri, 13 Feb 2026 04:46:43 +0000
Subject: [PATCH 1/2] Make grouped weights opt-in

Signed-off-by: Kirthi Shankar Sivamani
---
 tests/pytorch/test_sanity.py                  | 20 +++++++++++++++++--
 .../pytorch/module/grouped_linear.py          |  5 ++++-
 2 files changed, 22 insertions(+), 3 deletions(-)

diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 033a6a7ffb..220180d5c6 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -585,10 +585,19 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
 @pytest.mark.parametrize("fp8_model_params", all_boolean)
 @pytest.mark.parametrize("use_bias", all_boolean)
+@pytest.mark.parametrize("single_param", all_boolean)
 @pytest.mark.parametrize("empty_split", ["first", "last", "middle"])
 @pytest.mark.parametrize("num_gemms", [4])
 def test_sanity_grouped_linear(
-    dtype, bs, model, fp8_recipe, fp8_model_params, use_bias, num_gemms, empty_split
+    dtype,
+    bs,
+    model,
+    fp8_recipe,
+    fp8_model_params,
+    use_bias,
+    single_param,
+    num_gemms,
+    empty_split,
 ):
     if NVTE_TEST_NVINSPECT_ENABLED and fp8_model_params:
         pytest.skip("FP8 model parameters are not supported in debug mode.")
@@ -598,6 +607,9 @@
     bs = bs * 16
     num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
 
+    if single_param:
+        os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"
+
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
@@ -617,7 +629,8 @@
     # Verify that weights are stored in contiguous GroupedTensor storage.
     weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
     if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()):
-        check_grouped_tensor_pointers(weights, fp8_recipe)
+        if single_param:
+            check_grouped_tensor_pointers(weights, fp8_recipe)
 
     inp_hidden_states = torch.randn(
         num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
@@ -636,6 +649,9 @@
     loss.backward()
     assert out.shape == (num_tokens, ffn_hidden_size)
 
+    if single_param:
+        del os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"]
+
 
 @pytest.mark.parametrize("dtype", param_types)
 @pytest.mark.parametrize("fp8_recipe", fp8_recipes)
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index b6596bc2e9..bf0e3d50a3 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -6,6 +6,7 @@
 from typing import Union, Optional, Callable, Tuple, List
 from itertools import chain
 import warnings
+import os
 import functools
 
 import torch
@@ -793,7 +794,9 @@ def make_grouped_weights(self, defer_init=False) -> None:
 
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
-        self.make_grouped_weights(defer_init=defer_init)
+        # Grouped tensor weights is an opt-in feature.
+        if bool(int(os.getenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0"))):
+            self.make_grouped_weights(defer_init=defer_init)
 
     def set_tensor_parallel_attributes(self, defer_init=False) -> None:
         """Set attributes needed for TP"""

From 197f2e6b919ea867c5bf03eee4bc7e1fe05ac54a Mon Sep 17 00:00:00 2001
From: Kirthi Shankar Sivamani
Date: Fri, 13 Feb 2026 05:49:13 +0000
Subject: [PATCH 2/2] Change varname

Signed-off-by: Kirthi Shankar Sivamani
---
 tests/pytorch/test_sanity.py                        | 4 ++--
 transformer_engine/pytorch/module/grouped_linear.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 220180d5c6..d47bc553b0 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -608,7 +608,7 @@ def test_sanity_grouped_linear(
     num_tokens = bs * config.max_seqlen_q * (num_gemms - 1)
 
     if single_param:
-        os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"] = "1"
+        os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"
 
     if fp8_recipe is not None:
         if not is_fp8_supported(config):
             pytest.skip("Model config does not support FP8")
@@ -650,7 +650,7 @@ def test_sanity_grouped_linear(
     assert out.shape == (num_tokens, ffn_hidden_size)
 
     if single_param:
-        del os.environ["NVTE_GROUPED_LINEAR_SINGLE_PARAM"]
+        del os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"]
 
 
 @pytest.mark.parametrize("dtype", param_types)
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index bf0e3d50a3..2f859e748b 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -795,7 +795,7 @@ def make_grouped_weights(self, defer_init=False) -> None:
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
         # Grouped tensor weights is an opt-in feature.
-        if bool(int(os.getenv("NVTE_GROUPED_LINEAR_SINGLE_PARAM", "0"))):
+        if bool(int(os.getenv("NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS", "0"))):
            self.make_grouped_weights(defer_init=defer_init)
 
     def set_tensor_parallel_attributes(self, defer_init=False) -> None:
         """Set attributes needed for TP"""
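
Note for reviewers (not part of the patches): after this series, GroupedLinear keeps its default per-GEMM weight parameters unless the environment variable is set before the module is constructed, because reset_parameters() reads the flag at initialization time. Below is a minimal sketch of how a caller would opt in. The GroupedLinear constructor arguments (num_gemms, in_features, out_features) and the m_splits argument to forward() are taken from the public Transformer Engine API and are not part of this diff, so treat them as assumptions to verify against the installed version.

    import os
    import torch
    import transformer_engine.pytorch as te

    # Opt in to contiguous grouped weight storage. reset_parameters() checks
    # the flag during module construction, so it must be set beforehand.
    os.environ["NVTE_ALLOC_CONTIGUOUS_GROUPED_LINEAR_WEIGHTS"] = "1"

    num_gemms, hidden_size, ffn_hidden_size = 4, 1024, 4096
    layer = te.GroupedLinear(num_gemms, hidden_size, ffn_hidden_size, bias=True)

    # Per-GEMM weights are still exposed as weight0..weight{num_gemms-1}; with
    # the flag set they are expected to share grouped storage, which is what
    # check_grouped_tensor_pointers verifies in the test above.
    weights = [getattr(layer, f"weight{i}") for i in range(num_gemms)]

    # Forward pass: m_splits gives the per-GEMM token counts (equal splits here).
    inp = torch.randn(num_gemms * 16, hidden_size, device="cuda", requires_grad=True)
    out = layer(inp, [16] * num_gemms)

Without the environment variable, behavior matches the state before this series: make_grouped_weights() is not called and the per-GEMM weights are allocated individually.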