diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py
new file mode 100644
index 000000000..04d736ebd
--- /dev/null
+++ b/modelopt/torch/export/moe_utils.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for Mixture-of-Experts (MoE) model export."""
+
+from pathlib import Path
+
+import torch.nn as nn
+
+
+def sync_expert_amax_low_tokens(model: nn.Module, threshold: float = 0.05):
+ """Sync expert amax values across experts if the number of tokens routed to an expert is less than the threshold.
+
+ For each MoE layer, this function collects the maximum amax value across all experts for
+ each input quantizer, then overwrites the amax of experts whose token count falls below
+ ``threshold * mean_token_count`` with that maximum.
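+
+    A minimal usage sketch (illustrative; ``calib_loop`` is a placeholder calibration loop, and
+    the model is assumed to have been calibrated with ``modelopt.torch.quantization`` so that
+    each MoE layer carries an ``expert_token_count`` buffer)::
+
+        import modelopt.torch.quantization as mtq
+
+        model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=calib_loop)
+        # Experts that saw fewer than 5% of the mean token count inherit the
+        # per-layer maximum amax so their scales are not under-calibrated.
+        sync_expert_amax_low_tokens(model, threshold=0.05)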
+ """
+ for module_name, module in model.named_modules():
+ if not (hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0):
+ continue
+
+ experts = list(module.experts.children())
+ num_experts = module.expert_token_count.shape[0]
+
+ expert_amax_max = {}
+ for expert in experts:
+ for quantizer_name, quantizer in expert.named_modules():
+                # AWQ calibration produces pre_quant_scale; amax syncing does not apply,
+                # so skip it for the entire model.
+ if hasattr(quantizer, "pre_quant_scale") and quantizer.pre_quant_scale is not None:
+ return
+
+ if (
+ "input_quantizer" in quantizer_name
+ and hasattr(quantizer, "_amax")
+ and quantizer._amax is not None
+ and quantizer._amax.numel() == 1
+ ):
+ prev = expert_amax_max.get(quantizer_name)
+ cur = quantizer._amax.detach().clone()
+ if prev is None or cur > prev:
+ expert_amax_max[quantizer_name] = cur
+
+ if not expert_amax_max:
+ continue
+
+ avg_token_count = module.expert_token_count.float().mean().item()
+ token_threshold = avg_token_count * threshold
+
+ print(f"[sync_expert_amax] {module_name}")
+ print(f" token counts : {module.expert_token_count.tolist()}")
+ print(f" avg={avg_token_count:.1f} threshold(<)={token_threshold:.1f}")
+ print(f" tracked quantizers: {list(expert_amax_max.keys())}")
+
+ for i in range(num_experts):
+ token_count_i = module.expert_token_count[i].item()
+ if token_count_i < token_threshold:
+ expert_i = experts[i]
+ print(f" expert {i}: token_count={token_count_i} — syncing amax")
+ for quantizer_name, quantizer in expert_i.named_modules():
+ if quantizer_name in expert_amax_max:
+ old_val = quantizer._amax.item() if quantizer._amax is not None else None
+ quantizer._amax = expert_amax_max[quantizer_name].clone()
+ print(f" {quantizer_name}: {old_val} -> {quantizer._amax.item()}")
+
+
+def save_expert_token_count_table(
+ model: nn.Module, output_dir: str | Path | None = None, threshold: float = 0.05
+):
+ """Collect expert_token_count from all quantized MoE layers and save as an HTML table.
+
+ The table has rows for each MoE layer and columns for each expert, with cell values
+ showing the number of tokens routed to that expert during calibration.
+
+ Args:
+ model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
+ output_dir: Directory to save the HTML file. Defaults to current directory.
+        threshold: Fraction of the mean token count below which a cell is highlighted as a
+            low-token expert. Defaults to 0.05.
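+
+    A minimal usage sketch (``export_dir`` is a placeholder for any writable directory)::
+
+        save_expert_token_count_table(model, export_dir, threshold=0.05)
+        # Writes <export_dir>/moe.html; experts whose token count falls below
+        # ``threshold`` times the layer mean are highlighted in red.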
+ """
+ rows = []
+ for name, module in model.named_modules():
+ if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
+ rows.append((name, module.expert_token_count))
+
+ if not rows:
+ return
+
+ num_experts = rows[0][1].shape[0]
+    html_parts = [
+        "<table border='1' cellpadding='4' style='border-collapse: collapse;'>",
+        "<caption>Expert Token Counts (per MoE layer)</caption>",
+        "<tr><th>Layer/Expert</th>",
+    ]
+    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
+    html_parts.append("</tr>")
+
+    for name, counts in rows:
+        avg = counts.float().mean().item()
+        html_parts.append(f"<tr><td>{name}</td>")
+        for c in counts.tolist():
+            if avg > 0 and c < avg * threshold:
+                style = ' style="background: #ffcccc;"'
+            else:
+                style = ""
+            html_parts.append(f"<td{style}>{c}</td>")
+        html_parts.append("</tr>")
+
+    html_parts.append("</table>")
+ html_content = "\n".join(html_parts)
+
+ if output_dir is None:
+ output_dir = Path(".")
+ output_path = Path(output_dir) / "moe.html"
+ output_path.write_text(html_content)
+ print(f"Expert token count table saved to {output_path}")
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 447fc43a7..a5d948836 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -76,6 +76,7 @@
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
+from .moe_utils import save_expert_token_count_table, sync_expert_amax_low_tokens
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
@@ -1003,6 +1004,9 @@ def export_hf_checkpoint(
try:
post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
+ save_expert_token_count_table(model, export_dir)
+ sync_expert_amax_low_tokens(model)
+
if hf_quant_config is not None:
# Save hf_quant_config.json for backward compatibility
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 807c92c2c..69f56eef2 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -450,20 +450,58 @@ class _QuantSparseMoe(QuantModule):
"""
def _setup(self):
- pass
+ num_experts = 0
+ if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
+ num_experts = self.gate.num_experts
+ elif hasattr(self, "num_experts"):
+ num_experts = self.num_experts
+ elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
+ num_experts = self.experts.num_experts
+
+ self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+ self._count_expert_tokens = False
+
+ if hasattr(self, "gate"):
+ self.gate.register_forward_hook(self._gate_forward_hook)
+
+ def _gate_forward_hook(self, module, input, output):
+ if not self._count_expert_tokens:
+ return
+ with torch.no_grad():
+ if isinstance(output, tuple) and len(output) >= 3:
+ # v5.x TopKRouter: returns (logits, scores, indices)
+ indices = output[2]
+ else:
+ # v4.x nn.Linear gate: returns logits tensor
+ logits = output if not isinstance(output, tuple) else output[0]
+ top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
+ _, indices = torch.topk(logits.float(), top_k, dim=-1)
+ counts = torch.bincount(
+ indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
+ )
+ self.expert_token_count += counts
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
- if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
- # If any of the experts are in calibration mode, we will forward all tokens to all experts
+ is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
+ has_prequant_scale = any(
+ getattr(m, "awq_lite", None) is not None or getattr(m, "awq_clip", None) is not None
+ for m in self.experts.modules()
+ )
+ if is_calib and has_prequant_scale:
+ # If any of the experts are in AWQ calibration mode, we will forward all tokens to all experts
# This is used only for calibration, we need to re-calculate the actual outputs again using
# the original top_k
if TRANSFORMERS_VERSION_GE_5_0:
assert hasattr(self, "gate")
# Path for transformers >= 5.0
- original_top_k = self.gate.topk
- self.gate.topk = self.gate.num_experts
+ original_top_k = self.gate.top_k
+ self.gate.top_k = self.gate.num_experts
super().forward(hidden_states)
- self.gate.topk = original_top_k
+ self.gate.top_k = original_top_k
else:
# Path for transformers < 5.0
original_top_k = self.top_k
@@ -475,7 +513,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise ValueError(f"Could not find num_experts in module {self}")
super().forward(hidden_states)
self.top_k = original_top_k
-        return super().forward(hidden_states)
+        # Enable counting only for the real-routing forward during calibration, so the
+        # all-experts AWQ pass above does not skew the per-expert token statistics.
+        self._count_expert_tokens = is_calib
+        output = super().forward(hidden_states)
+        self._count_expert_tokens = False
+        return output
class _QuantLlama4TextExperts(QuantModule):
@@ -765,10 +806,7 @@ def unpack_weight(self):
try:
- from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
-
- if Llama4TextMoe not in QuantModuleRegistry:
- QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
+ from transformers.models.llama4.modeling_llama4 import Llama4TextExperts
if Llama4TextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -791,16 +829,6 @@ def unpack_weight(self):
except ImportError:
pass
-try:
- from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
- if MixtralSparseMoeBlock not in QuantModuleRegistry:
- QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
- _QuantSparseMoe
- )
-except ImportError:
- pass
-
try:
from transformers.models.falcon.modeling_falcon import FalconLinear
@@ -809,36 +837,6 @@ def unpack_weight(self):
except ImportError:
pass
-try:
- from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
- if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
- QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
- _QuantSparseMoe
- )
-except ImportError:
- pass
-
-try:
- from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
-
- if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
- QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
- _QuantSparseMoe
- )
-except ImportError:
- pass
-
-try:
- from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
-
- if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
- QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
- _QuantSparseMoe
- )
-except ImportError:
- pass
-
try:
from compressed_tensors.linear.compressed_linear import CompressedLinear
@@ -850,15 +848,7 @@ def unpack_weight(self):
pass
try:
- from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
- Qwen3VLMoeTextExperts,
- Qwen3VLMoeTextSparseMoeBlock,
- )
-
- if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
- QuantModuleRegistry.register(
- {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
- )(_QuantSparseMoe)
+ from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
@@ -989,15 +979,55 @@ def register_falcon_linears_on_the_fly(model):
QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
-def register_minimax_m2_moe_on_the_fly(model):
- """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+def _is_sparse_moe_block(module):
+ """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.
+
+ All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
+ share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
+ (``top_k`` and ``num_experts``), and an ``experts`` sub-module.
- MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
+ This function detects that pattern instead of relying on class names, making it forward-compatible
+ with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but
+ use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom
+ ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives.
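+
+    Illustrative check (assumes a Mixtral-style model, whose MoE block exposes ``experts``
+    together with ``top_k``/``num_experts``)::
+
+        moe_block = model.model.layers[0].block_sparse_moe
+        assert _is_sparse_moe_block(moe_block)  # matches the structural pattern
+        assert not _is_sparse_moe_block(model.lm_head)  # plain Linear, no ``experts``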
"""
- if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
- moe_type = type(model.model.layers[0].block_sparse_moe)
- if QuantModuleRegistry.get(moe_type) is None:
- QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+ if not hasattr(module, "experts"):
+ return False
+
+    # Primary: gate sub-module has top_k + num_experts (standard TopKRouter pattern)
+ if hasattr(module, "gate"):
+ gate = module.gate
+ has_topk = hasattr(gate, "top_k")
+ has_num_experts = hasattr(gate, "num_experts")
+ if has_topk and has_num_experts:
+ return True
+
+ # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
+ return hasattr(module, "top_k") and hasattr(module, "num_experts")
+
+
+def register_sparse_moe_on_the_fly(model):
+ """Auto-detect and register MOE modules as _QuantSparseMoe.
+
+ Walks the model tree, identifies MoE blocks by their structural attributes
+ (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
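+
+    A minimal sketch (illustrative; ``Qwen/Qwen3-30B-A3B`` stands in for any HF MoE checkpoint,
+    and this hook is normally invoked automatically for supported HF models)::
+
+        from transformers import AutoModelForCausalLM
+
+        model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen3-30B-A3B")
+        register_sparse_moe_on_the_fly(model)
+        # Qwen3MoeSparseMoeBlock is now mapped to _QuantSparseMoe in QuantModuleRegistry.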
+ """
+ registered_types = set()
+ for name, module in model.named_modules():
+ mod_type = type(module)
+
+ # Avoid duplicate registration: skip if we already processed this type
+ # in this walk, or if it was previously registered in the QuantModuleRegistry.
+ if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
+ continue
+
+ if _is_sparse_moe_block(module):
+ print(
+ f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, "
+ f"registering with _QuantSparseMoe.\033[0m"
+ )
+ QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
+ registered_types.add(mod_type)
def _is_supported_hf_model(model):
@@ -1065,7 +1095,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
- register_minimax_m2_moe_on_the_fly,
+ register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
]