From 28d5686e40784462bf00567feccb89a7fcfacc63 Mon Sep 17 00:00:00 2001
From: Jennifer Chen
Date: Wed, 18 Feb 2026 20:06:24 +0000
Subject: [PATCH 1/4] sync moe input quantizer only

Signed-off-by: Jennifer Chen
---
 modelopt/torch/quantization/model_calib.py   | 11 ++-
 .../torch/quantization/plugins/megatron.py   | 31 ++++---
 .../quantization/plugins/test_megatron.py    | 82 ++++++++++++++++++-
 3 files changed, 106 insertions(+), 18 deletions(-)

diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 6455447ac..081aa6f1c 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -95,13 +95,20 @@ def _check_moe_calibration_complete(quantizer, parallel_state):
 
 
 @torch.no_grad()
-def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
+def max_calibrate(
+    model: nn.Module,
+    forward_loop: ForwardLoop | None = None,
+    distributed_sync=True,
+    shared_moe_weight_scale=True,
+):
     """Calibrate the model using max.
 
     Args:
         model: Model to be calibrated.
         forward_loop: A callable which takes the model as argument and forwards calibration data
             through the model.
+        distributed_sync: Whether to sync amax across distributed processes.
+        shared_moe_weight_scale: Whether to share the weight scale across local experts.
 
     See :class:`MaxCalibConfig ` for
     details on the remaining arguments.
@@ -116,7 +123,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
     # Sync amax across local experts within each rank (for SequentialMLP)
     for name, module in model.named_modules():
         if hasattr(module, "layer_sync_moe_local_experts_amax"):
-            module.layer_sync_moe_local_experts_amax()
+            module.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
 
     if not distributed_sync:
         return
diff --git a/modelopt/torch/quantization/plugins/megatron.py b/modelopt/torch/quantization/plugins/megatron.py
index 6e92fce90..adfb94edd 100644
--- a/modelopt/torch/quantization/plugins/megatron.py
+++ b/modelopt/torch/quantization/plugins/megatron.py
@@ -574,11 +574,11 @@ def _setup(self):
             expert.linear_fc1.parallel_state = self.parallel_state
             expert.linear_fc2.parallel_state = self.parallel_state
 
-    def layer_sync_moe_local_experts_amax(self):
-        """Sync amax across local experts in a SequentialMLP.
+    def layer_sync_moe_local_experts_amax(self, shared_moe_weight_scale=True):
+        """Sync input quantizer amax across local experts in a SequentialMLP, and optionally the weight scale.
 
-        Synchronize the amax values across local experts in a lyaer such that all local experts will
-        share the same amax. This function operates on a single rank and does not require distributed sync.
+        Ensures all experts have the same input quantizer amax. This function operates
+        on a single rank and does not require distributed sync.
 
         Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
         This function should be called before the distributed sync to ensure the amax values
 
         Note:
             Because there are logic which calls collective communication based on whether amax is not None,
-            We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
-            when synchroizing over EP since some ranks may have amax None and not calling the collective
+            We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
+            when synchronizing over EP since some ranks may have amax None and not calling the collective
             communication.
+
+        Args:
+            shared_moe_weight_scale: Whether to share the weight scale across local experts.
         """
         # Collect amax from all local experts
         amax_dict = {}
         for expert in self.local_experts:
             for name, module in expert.named_modules():
                 if isinstance(module, TensorQuantizer) and module.amax is not None:
-                    stored_amax = amax_dict.get(name)
-                    amax_tensor = module.amax.detach().clone()
-                    amax_dict[name] = (
-                        amax_tensor
-                        if stored_amax is None
-                        else torch.maximum(stored_amax, amax_tensor)
-                    )
+                    if shared_moe_weight_scale or ("weight_quantizer" not in name):
+                        # Sync both quantizers or only sync input quantizer
+                        stored_amax = amax_dict.get(name)
+                        amax_tensor = module.amax.detach().clone()
+                        amax_dict[name] = (
+                            amax_tensor
+                            if stored_amax is None
+                            else torch.maximum(stored_amax, amax_tensor)
+                        )
 
         # Apply synchronized amax values back to all local experts
         for expert in self.local_experts:
diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
index b107eca71..cbe736dc6 100644
--- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -735,6 +735,85 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
     )
 
 
+@pytest.mark.parametrize("ep_size", [1, 2])
+@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
+@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
+def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
+    """Test amax synchronization across local experts."""
+    size = torch.cuda.device_count()
+    if size < ep_size:
+        pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")
+
+    spawn_multiprocess_job(
+        size=size,
+        job=partial(
+            _test_layer_sync_moe_local_experts_amax,
+            ep_size,
+            moe_grouped_gemm,
+            shared_moe_weight_scale,
+        ),
+        backend="nccl",
+    )
+
+
+def _test_layer_sync_moe_local_experts_amax(
+    ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
+):
+    initialize_for_megatron(
+        tensor_model_parallel_size=1,
+        pipeline_model_parallel_size=1,
+        expert_model_parallel_size=ep_size,
+        expert_tensor_parallel_size=1,
+        seed=SEED,
+    )
+    model = _gpt_model_provider(
+        tp_size=1,
+        ep_size=ep_size,
+        etp_size=1,
+        hidden_size=256,
+        moe_grouped_gemm=moe_grouped_gemm,
+        use_te=moe_grouped_gemm,
+        num_moe_experts=8,
+        transformer_impl="modelopt",
+    )
+    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
+
+    # Sync amax across local experts in each layer
+    for layer in model.decoder.layers:
+        layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
+
+    for layer in model.decoder.layers:
+        fc1_amax = None
+        fc2_amax = None
+        for expert in layer.mlp.experts.local_experts:
+            assert expert.linear_fc1.input_quantizer.amax is not None
+            assert expert.linear_fc2.input_quantizer.amax is not None
+            if fc1_amax is None:
+                fc1_amax = expert.linear_fc1.input_quantizer.amax
+            else:
+                assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
+            if fc2_amax is None:
+                fc2_amax = expert.linear_fc2.input_quantizer.amax
+            else:
+                assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
+
+    if shared_moe_weight_scale:
+        for layer in model.decoder.layers:
+            fc1_amax = None
+            fc2_amax = None
+            for expert in layer.mlp.experts.local_experts:
+                assert expert.linear_fc1.weight_quantizer.amax is not None
+                assert expert.linear_fc2.weight_quantizer.amax is not None
+                if fc1_amax is None:
+                    fc1_amax = expert.linear_fc1.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+                if fc2_amax is None:
+                    fc2_amax = expert.linear_fc2.weight_quantizer.amax
+                else:
+                    assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+
+
 def _test_expert_model_parallel_amax_sync(
     tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
 ):
@@ -815,9 +894,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
     if size < ep_size * etp_size:
         pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")
 
-    if moe_grouped_gemm:
-        pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")
-
     spawn_multiprocess_job(
         size=size,
         job=partial(

From f74cd24eac35c2a1dea17f58fef386297c93717d Mon Sep 17 00:00:00 2001
From: Jennifer Chen
Date: Wed, 18 Feb 2026 20:16:35 +0000
Subject: [PATCH 2/4] update configs with shared_moe_weight_scale

Signed-off-by: Jennifer Chen
---
 modelopt/torch/quantization/config.py           | 14 ++++++++++----
 .../torch/quantization/plugins/test_megatron.py |  5 +----
 2 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 291acba03..171f6181f 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -214,7 +214,7 @@
         **_default_disabled_quantizer_cfg,
         **_mamba_moe_disabled_quantizer_cfg,
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
 
 MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
@@ -226,7 +226,7 @@
         "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
         "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
 
 FP8_PER_CHANNEL_PER_TOKEN_CFG = {
@@ -437,7 +437,7 @@
         **_default_disabled_quantizer_cfg,
         **_mamba_moe_disabled_quantizer_cfg,
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
 MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
     "quant_cfg": {
@@ -458,7 +458,7 @@
         "*mixer.in_proj*": {"enable": False},  # Skip mamba linear
         "*mixer.out_proj*": {"enable": False},  # Skip mamba linear
     },
-    "algorithm": "max",
+    "algorithm": {"method": "max", "shared_moe_weight_scale": False},
 }
 
 
@@ -1087,6 +1087,12 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
         description="If True, the amax will be synced across the distributed processes.",
     )
 
+    shared_moe_weight_scale: bool | None = ModeloptField(
+        default=True,
+        title="Whether to share the weight scale across local experts.",
+        description="If True, the weight scale will be shared across local experts.",
+    )
+
 
 class MseCalibConfig(QuantizeAlgorithmConfig):
     """Configuration for per-tensor MSE calibration.
diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
index cbe736dc6..49cbfba48 100644
--- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -473,10 +473,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device,
 
 @pytest.mark.parametrize(
     "config",
-    [
-        NVFP4_GEMM_KV_CFG,
-        FP8_GEMM_KV_CFG,
-    ],
+    [NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG],
 )
 def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config):
     """Test sharded state dict for hybrid Mamba MOE models."""

From 002e6c6f134cf6d99b60913f698a6c1459161de1 Mon Sep 17 00:00:00 2001
From: Jennifer Chen
Date: Wed, 18 Feb 2026 20:32:47 +0000
Subject: [PATCH 3/4] check when weight quantizers not same

Signed-off-by: Jennifer Chen
---
 .../quantization/plugins/test_megatron.py | 36 +++++++++++--------
 1 file changed, 21 insertions(+), 15 deletions(-)

diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
index 49cbfba48..8ee0d40a6 100644
--- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -794,21 +794,27 @@ def _test_layer_sync_moe_local_experts_amax(
             else:
                 assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)
 
-    if shared_moe_weight_scale:
-        for layer in model.decoder.layers:
-            fc1_amax = None
-            fc2_amax = None
-            for expert in layer.mlp.experts.local_experts:
-                assert expert.linear_fc1.weight_quantizer.amax is not None
-                assert expert.linear_fc2.weight_quantizer.amax is not None
-                if fc1_amax is None:
-                    fc1_amax = expert.linear_fc1.weight_quantizer.amax
-                else:
-                    assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
-                if fc2_amax is None:
-                    fc2_amax = expert.linear_fc2.weight_quantizer.amax
-                else:
-                    assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+    for layer in model.decoder.layers:
+        fc1_amax = None
+        fc2_amax = None
+        for expert in layer.mlp.experts.local_experts:
+            assert expert.linear_fc1.weight_quantizer.amax is not None
+            assert expert.linear_fc2.weight_quantizer.amax is not None
+            if fc1_amax is None:
+                fc1_amax = expert.linear_fc1.weight_quantizer.amax
+            elif shared_moe_weight_scale:
+                assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+            else:
+                assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
+                fc1_amax = expert.linear_fc1.weight_quantizer.amax  # update most recent amax
+
+            if fc2_amax is None:
+                fc2_amax = expert.linear_fc2.weight_quantizer.amax
+            elif shared_moe_weight_scale:
+                assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+            else:
+                assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
+                fc2_amax = expert.linear_fc2.weight_quantizer.amax
 
 
 def _test_expert_model_parallel_amax_sync(

From 023b0a310a62069a0effd6363565b8d83a6e8f3b Mon Sep 17 00:00:00 2001
From: Jennifer Chen
Date: Wed, 18 Feb 2026 21:22:28 +0000
Subject: [PATCH 4/4] fix test

Signed-off-by: Jennifer Chen
---
 .../torch/quantization/plugins/test_megatron.py | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
index 8ee0d40a6..bad16a08f 100644
--- a/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
+++ b/tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -773,9 +773,13 @@ def _test_layer_sync_moe_local_experts_amax(
         num_moe_experts=8,
         transformer_impl="modelopt",
     )
-    model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, get_forward(model))
+    quant_cfg = mtq.FP8_DEFAULT_CFG
+    if not shared_moe_weight_scale:
+        quant_cfg = copy.deepcopy(quant_cfg)
+        quant_cfg["algorithm"] = {"method": "max", "shared_moe_weight_scale": False}
+    model = mtq.quantize(model, quant_cfg, get_forward(model))
 
-    # Sync amax across local experts in each layer
+    # does layer_sync_moe_local_experts_amax already happen in mtq.quantize if EP=1?
     for layer in model.decoder.layers:
         layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)
 
@@ -812,9 +816,7 @@ def _test_layer_sync_moe_local_experts_amax(
                 fc2_amax = expert.linear_fc2.weight_quantizer.amax
             elif shared_moe_weight_scale:
                 assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
-            else:
-                assert not torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
-                fc2_amax = expert.linear_fc2.weight_quantizer.amax
+            # FC2 amaxes are the same since the input to the layer is all the same
 
 
 def _test_expert_model_parallel_amax_sync(
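
Usage sketch (illustrative, not part of the patches): the snippet below mirrors how PATCH 4's test drives the new shared_moe_weight_scale knob via the "algorithm" dict added to MaxCalibConfig. Only mtq.quantize, mtq.FP8_DEFAULT_CFG, and the {"method": "max", "shared_moe_weight_scale": ...} dict come from the patches; the quantize_moe_model wrapper, the model object, and calib_forward_loop are hypothetical placeholders a caller would supply.

    import copy

    import modelopt.torch.quantization as mtq

    def quantize_moe_model(model, calib_forward_loop, share_moe_weight_scale=True):
        """Quantize an MoE model, optionally keeping per-expert weight scales."""
        quant_cfg = mtq.FP8_DEFAULT_CFG
        if not share_moe_weight_scale:
            # With the new MaxCalibConfig field set to False, only the input
            # quantizer amax is synced across local experts during calibration;
            # each expert keeps its own weight scale.
            quant_cfg = copy.deepcopy(quant_cfg)
            quant_cfg["algorithm"] = {"method": "max", "shared_moe_weight_scale": False}
        return mtq.quantize(model, quant_cfg, calib_forward_loop)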