14 changes: 10 additions & 4 deletions modelopt/torch/quantization/config.py
@@ -214,7 +214,7 @@
**_default_disabled_quantizer_cfg,
**_mamba_moe_disabled_quantizer_cfg,
},
"algorithm": "max",
"algorithm": {"method": "max", "shared_moe_weight_scale": False},
}

MAMBA_MOE_FP8_CONSERVATIVE_CFG = {
@@ -226,7 +226,7 @@
"*mixer.in_proj*": {"enable": False}, # Skip mamba linear
"*mixer.out_proj*": {"enable": False}, # Skip mamba linear
},
"algorithm": "max",
"algorithm": {"method": "max", "shared_moe_weight_scale": False},
}

FP8_PER_CHANNEL_PER_TOKEN_CFG = {
@@ -437,7 +437,7 @@
**_default_disabled_quantizer_cfg,
**_mamba_moe_disabled_quantizer_cfg,
},
"algorithm": "max",
"algorithm": {"method": "max", "shared_moe_weight_scale": False},
}
MAMBA_MOE_NVFP4_CONSERVATIVE_CFG = {
"quant_cfg": {
@@ -458,7 +458,7 @@
"*mixer.in_proj*": {"enable": False}, # Skip mamba linear
"*mixer.out_proj*": {"enable": False}, # Skip mamba linear
},
"algorithm": "max",
"algorithm": {"method": "max", "shared_moe_weight_scale": False},
}


@@ -1087,6 +1087,12 @@ class MaxCalibConfig(QuantizeAlgorithmConfig):
description="If True, the amax will be synced across the distributed processes.",
)

shared_moe_weight_scale: bool | None = ModeloptField(
default=True,
title="Whether to share the weight scale across local experts.",
description="If True, the weight scale will be shared across local experts.",
)


class MseCalibConfig(QuantizeAlgorithmConfig):
"""Configuration for per-tensor MSE calibration.
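For reference, a minimal usage sketch of the new knob (not part of this diff): it mirrors the pattern used in the tests further down, and `model` plus `forward_loop` are assumed placeholders for an already-built model and a calibration loop.

import copy

import modelopt.torch.quantization as mtq

# Start from a predefined config and opt out of the shared MoE weight scale.
quant_cfg = copy.deepcopy(mtq.FP8_DEFAULT_CFG)
quant_cfg["algorithm"] = {"method": "max", "shared_moe_weight_scale": False}

# `model` and `forward_loop` are assumed placeholders; forward_loop(model) should
# push calibration batches through the model.
model = mtq.quantize(model, quant_cfg, forward_loop)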
11 changes: 9 additions & 2 deletions modelopt/torch/quantization/model_calib.py
@@ -95,13 +95,20 @@ def _check_moe_calibration_complete(quantizer, parallel_state):


@torch.no_grad()
def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, distributed_sync=True):
def max_calibrate(
model: nn.Module,
forward_loop: ForwardLoop | None = None,
distributed_sync=True,
shared_moe_weight_scale=True,
):
"""Calibrate the model using max.

Args:
model: Model to be calibrated.
forward_loop: A callable which takes the model as argument and
forwards calibration data through the model.
distributed_sync: Whether to sync amax across distributed processes.
shared_moe_weight_scale: Whether to share the weight scale across local experts.

See :class:`MaxCalibConfig <modelopt.torch.quantization.config.MaxCalibConfig>` for
details on the remaining arguments.
@@ -116,7 +123,7 @@ def max_calibrate(model: nn.Module, forward_loop: ForwardLoop | None = None, dis
# Sync amax across local experts within each rank (for SequentialMLP)
for name, module in model.named_modules():
if hasattr(module, "layer_sync_moe_local_experts_amax"):
module.layer_sync_moe_local_experts_amax()
module.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)

if not distributed_sync:
return
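As a hedged illustration of the updated signature (not part of this diff), max_calibrate can also be driven directly; `model` and `calibration_batches` are assumed placeholders, and in the normal flow the flag arrives via the "algorithm" dict shown in config.py above.

from modelopt.torch.quantization.model_calib import max_calibrate

def forward_loop(m):
    # `calibration_batches` is an assumed placeholder for real calibration data.
    for batch in calibration_batches:
        m(batch)

max_calibrate(
    model,                          # assumed: a model already prepared for calibration
    forward_loop=forward_loop,
    distributed_sync=True,          # then sync amax across EP/ETP ranks
    shared_moe_weight_scale=False,  # keep per-expert weight scales
)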
31 changes: 18 additions & 13 deletions modelopt/torch/quantization/plugins/megatron.py
@@ -574,34 +574,39 @@ def _setup(self):
expert.linear_fc1.parallel_state = self.parallel_state
expert.linear_fc2.parallel_state = self.parallel_state

def layer_sync_moe_local_experts_amax(self):
"""Sync amax across local experts in a SequentialMLP.
def layer_sync_moe_local_experts_amax(self, shared_moe_weight_scale=True):
"""Sync input quantizer amax across local experts in a SequentialMLP, and optionally weight scale.

Synchronize the amax values across local experts in a lyaer such that all local experts will
share the same amax. This function operates on a single rank and does not require distributed sync.
Ensures all experts have the same input quantizer amax. This function operates
on a single rank and does not require distributed sync.

Distributed amax sync across EP and ETP (for RowParallel) happens in model_calib.max_calibrate().
This function should be called before the distributed sync to ensure the amax values
are synchronized across the layer first.

Note:
Because there is logic that calls collective communication based on whether amax is not None,
We need to garuantee that all experts must have amax. Otherwise, there will be deadlock
when synchroizing over EP since some ranks may have amax None and not calling the collective
We need to guarantee that all experts must have amax. Otherwise, there will be deadlock
when synchronizing over EP since some ranks may have amax None and not calling the collective
communication.

Args:
shared_moe_weight_scale: Whether to share the weight scale across local experts.
"""
# Collect amax from all local experts
amax_dict = {}
for expert in self.local_experts:
for name, module in expert.named_modules():
if isinstance(module, TensorQuantizer) and module.amax is not None:
stored_amax = amax_dict.get(name)
amax_tensor = module.amax.detach().clone()
amax_dict[name] = (
amax_tensor
if stored_amax is None
else torch.maximum(stored_amax, amax_tensor)
)
if shared_moe_weight_scale or ("weight_quantizer" not in name):
# Sync both quantizers or only sync input quantizer
stored_amax = amax_dict.get(name)
amax_tensor = module.amax.detach().clone()
amax_dict[name] = (
amax_tensor
if stored_amax is None
else torch.maximum(stored_amax, amax_tensor)
)

# Apply synchronized amax values back to all local experts
for expert in self.local_experts:
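The write-back half of this loop is collapsed above; below is a self-contained sketch of the whole pattern. `_Quantizer`, `sync_local_expert_amax`, and the dict-of-quantizers layout are illustrative stand-ins, not the module's real API.

import torch

class _Quantizer:
    """Stand-in for TensorQuantizer: just holds an `amax` tensor (or None)."""

    def __init__(self, amax=None):
        self.amax = amax

def sync_local_expert_amax(experts, shared_moe_weight_scale=True):
    """Make quantizer amax identical across experts.

    `experts` is a list of {quantizer_name: quantizer} dicts. Input quantizers are
    always synced; weight quantizers only when shared_moe_weight_scale is True.
    """
    amax_dict = {}
    # Pass 1: element-wise max of each quantizer's amax across all experts.
    for expert in experts:
        for name, q in expert.items():
            if q.amax is None:
                continue
            if shared_moe_weight_scale or "weight_quantizer" not in name:
                amax = q.amax.detach().clone()
                prev = amax_dict.get(name)
                amax_dict[name] = amax if prev is None else torch.maximum(prev, amax)
    # Pass 2: write the reduced values back so every expert shares the same amax.
    for expert in experts:
        for name, q in expert.items():
            if name in amax_dict:
                q.amax = amax_dict[name].clone()

experts = [
    {"fc1.input_quantizer": _Quantizer(torch.tensor(1.0)),
     "fc1.weight_quantizer": _Quantizer(torch.tensor(0.5))},
    {"fc1.input_quantizer": _Quantizer(torch.tensor(2.0)),
     "fc1.weight_quantizer": _Quantizer(torch.tensor(0.8))},
]
sync_local_expert_amax(experts, shared_moe_weight_scale=False)
# Both input quantizers now hold amax 2.0; weight quantizer amax stays per-expert.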
95 changes: 88 additions & 7 deletions tests/gpu_megatron/torch/quantization/plugins/test_megatron.py
@@ -473,10 +473,7 @@ def test_homogeneous_sharded_state_dict(tmp_path, config, compress, meta_device,

@pytest.mark.parametrize(
"config",
[
NVFP4_GEMM_KV_CFG,
FP8_GEMM_KV_CFG,
],
[NVFP4_GEMM_KV_CFG, FP8_GEMM_KV_CFG, mtq.MAMBA_MOE_NVFP4_CONSERVATIVE_CFG],
)
def test_homogeneous_sharded_state_dict_hybrid(tmp_path, config):
"""Test sharded state dict for hybrid Mamba MOE models."""
@@ -735,6 +732,93 @@ def test_te_grouped_vs_sequential_quantize(need_4_gpus):
)


@pytest.mark.parametrize("ep_size", [1, 2])
@pytest.mark.parametrize("moe_grouped_gemm", [True, False])
@pytest.mark.parametrize("shared_moe_weight_scale", [True, False])
def test_layer_sync_moe_local_experts_amax(ep_size, moe_grouped_gemm, shared_moe_weight_scale):
"""Test expert model parallel synchronization."""
size = torch.cuda.device_count()
if size < ep_size:
pytest.skip(f"Requires at least {ep_size} GPUs for expert model parallel test")

spawn_multiprocess_job(
size=size,
job=partial(
_test_layer_sync_moe_local_experts_amax,
ep_size,
moe_grouped_gemm,
shared_moe_weight_scale,
),
backend="nccl",
)


def _test_layer_sync_moe_local_experts_amax(
ep_size, moe_grouped_gemm, shared_moe_weight_scale, rank, size
):
initialize_for_megatron(
tensor_model_parallel_size=1,
pipeline_model_parallel_size=1,
expert_model_parallel_size=ep_size,
expert_tensor_parallel_size=1,
seed=SEED,
)
model = _gpt_model_provider(
tp_size=1,
ep_size=ep_size,
etp_size=1,
hidden_size=256,
moe_grouped_gemm=moe_grouped_gemm,
use_te=moe_grouped_gemm,
num_moe_experts=8,
transformer_impl="modelopt",
)
quant_cfg = mtq.FP8_DEFAULT_CFG
if not shared_moe_weight_scale:
quant_cfg = copy.deepcopy(quant_cfg)
quant_cfg["algorithm"] = {"method": "max", "shared_moe_weight_scale": False}
model = mtq.quantize(model, quant_cfg, get_forward(model))

# Does layer_sync_moe_local_experts_amax happen in mtq.quantize when EP=1? Call it explicitly here to be sure.
for layer in model.decoder.layers:
layer.mlp.experts.layer_sync_moe_local_experts_amax(shared_moe_weight_scale)

for layer in model.decoder.layers:
fc1_amax = None
fc2_amax = None
for expert in layer.mlp.experts.local_experts:
assert expert.linear_fc1.input_quantizer.amax is not None
assert expert.linear_fc2.input_quantizer.amax is not None
if fc1_amax is None:
fc1_amax = expert.linear_fc1.input_quantizer.amax
else:
assert torch.allclose(fc1_amax, expert.linear_fc1.input_quantizer.amax)
if fc2_amax is None:
fc2_amax = expert.linear_fc2.input_quantizer.amax
else:
assert torch.allclose(fc2_amax, expert.linear_fc2.input_quantizer.amax)

for layer in model.decoder.layers:
fc1_amax = None
fc2_amax = None
for expert in layer.mlp.experts.local_experts:
assert expert.linear_fc1.weight_quantizer.amax is not None
assert expert.linear_fc2.weight_quantizer.amax is not None
if fc1_amax is None:
fc1_amax = expert.linear_fc1.weight_quantizer.amax
elif shared_moe_weight_scale:
assert torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
else:
assert not torch.allclose(fc1_amax, expert.linear_fc1.weight_quantizer.amax)
fc1_amax = expert.linear_fc1.weight_quantizer.amax # update most recent amax

if fc2_amax is None:
fc2_amax = expert.linear_fc2.weight_quantizer.amax
elif shared_moe_weight_scale:
assert torch.allclose(fc2_amax, expert.linear_fc2.weight_quantizer.amax)
# FC2 amaxes are the same since the input to the layer is all the same


def _test_expert_model_parallel_amax_sync(
tp_size, ep_size, etp_size, moe_grouped_gemm, config, rank, size
):
@@ -815,9 +899,6 @@ def test_expert_parallel_sync(config, ep_size, etp_size, moe_grouped_gemm):
if size < ep_size * etp_size:
pytest.skip(f"Requires at least {ep_size * etp_size} GPUs for expert model parallel test")

if moe_grouped_gemm:
pytest.skip("TEGroupedMLP is not enabled in Megatron-LM currently")

spawn_multiprocess_job(
size=size,
job=partial(