From 919be0f6338e445b30cf83b1ec47984b1f5ee2eb Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Tue, 17 Feb 2026 22:44:16 +0000 Subject: [PATCH 1/8] Auto detect MOE layers Signed-off-by: Chenjie Luo --- .../torch/quantization/plugins/huggingface.py | 117 ++++++++---------- 1 file changed, 53 insertions(+), 64 deletions(-) diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 807c92c2c..6aed8c439 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -460,10 +460,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: if TRANSFORMERS_VERSION_GE_5_0: assert hasattr(self, "gate") # Path for transformers >= 5.0 - original_top_k = self.gate.topk - self.gate.topk = self.gate.num_experts + original_top_k = self.gate.top_k + self.gate.top_k = self.gate.num_experts super().forward(hidden_states) - self.gate.topk = original_top_k + self.gate.top_k = original_top_k else: # Path for transformers < 5.0 original_top_k = self.top_k @@ -765,10 +765,7 @@ def unpack_weight(self): try: - from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe - - if Llama4TextMoe not in QuantModuleRegistry: - QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe) + from transformers.models.llama4.modeling_llama4 import Llama4TextExperts if Llama4TextExperts not in QuantModuleRegistry: QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})( @@ -791,16 +788,6 @@ def unpack_weight(self): except ImportError: pass -try: - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - - if MixtralSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - try: from transformers.models.falcon.modeling_falcon import FalconLinear @@ -809,36 +796,6 @@ def unpack_weight(self): except ImportError: pass -try: - from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock - - if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - -try: - from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock - - if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - -try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock - - if Qwen3NextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - try: from compressed_tensors.linear.compressed_linear import CompressedLinear @@ -850,15 +807,7 @@ def unpack_weight(self): pass try: - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( - Qwen3VLMoeTextExperts, - Qwen3VLMoeTextSparseMoeBlock, - ) - - if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register( - {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"} - )(_QuantSparseMoe) + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts if Qwen3VLMoeTextExperts not in QuantModuleRegistry: 
QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})( @@ -989,15 +938,55 @@ def register_falcon_linears_on_the_fly(model): QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) -def register_minimax_m2_moe_on_the_fly(model): - """Register MiniMax M2 MoE modules as a QUANT_MODULE. +def _is_sparse_moe_block(module): + """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe. + + All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.) + share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes + (``top_k`` and ``num_experts``), and an ``experts`` sub-module. - MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly. + This function detects that pattern instead of relying on class names, making it forward-compatible + with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but + use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom + ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives. """ - if type(model).__name__ in ["MiniMaxM2ForCausalLM"]: - moe_type = type(model.model.layers[0].block_sparse_moe) - if QuantModuleRegistry.get(moe_type) is None: - QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe) + if not hasattr(module, "experts"): + return False + + # Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern) + if hasattr(module, "gate"): + gate = module.gate + has_topk = hasattr(gate, "top_k") + has_num_experts = hasattr(gate, "num_experts") + if has_topk and has_num_experts: + return True + + # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next) + return hasattr(module, "top_k") and hasattr(module, "num_experts") + + +def register_sparse_moe_on_the_fly(model): + """Auto-detect and register MOE modules as _QuantSparseMoe. + + Walks the model tree, identifies MoE blocks by their structural attributes + (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``. + """ + registered_types = set() + for name, module in model.named_modules(): + mod_type = type(module) + + # Avoid duplicate registration: skip if we already processed this type + # in this walk, or if it was previously registered in the QuantModuleRegistry. 
+ if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None: + continue + + if _is_sparse_moe_block(module): + print( + f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, " + f"registering with _QuantSparseMoe.\033[0m" + ) + QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe) + registered_types.add(mod_type) def _is_supported_hf_model(model): @@ -1065,7 +1054,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): [ register_falcon_linears_on_the_fly, register_dbrx_moe_on_the_fly, - register_minimax_m2_moe_on_the_fly, + register_sparse_moe_on_the_fly, register_hf_attentions_on_the_fly, convert_hf_parallel_linears_on_the_fly, ] From 8baeaafae198b10ecc6a0f518914c4fc26fd248d Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 18 Feb 2026 00:46:17 +0000 Subject: [PATCH 2/8] Report MOE tokens Signed-off-by: Chenjie Luo --- modelopt/torch/export/moe_utils.py | 76 +++++++++++++++++++ modelopt/torch/export/unified_export_hf.py | 3 + .../torch/quantization/plugins/huggingface.py | 40 +++++++++- 3 files changed, 116 insertions(+), 3 deletions(-) create mode 100644 modelopt/torch/export/moe_utils.py diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py new file mode 100644 index 000000000..b040bed3a --- /dev/null +++ b/modelopt/torch/export/moe_utils.py @@ -0,0 +1,76 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Mixture-of-Experts (MoE) model export.""" + +from pathlib import Path + +import torch.nn as nn + + +def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None): + """Collect expert_token_count from all quantized MoE layers and save as an HTML table. + + The table has rows for each MoE layer and columns for each expert, with cell values + showing the number of tokens routed to that expert during calibration. + + Args: + model: The model containing quantized MoE layers with ``expert_token_count`` attributes. + output_dir: Directory to save the HTML file. Defaults to current directory. + """ + rows = [] + for name, module in model.named_modules(): + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: + rows.append((name, module.expert_token_count)) + + if not rows: + return + + num_experts = rows[0][1].shape[0] + html_parts = [ + "", + "
<h2>Expert Token Counts (per MoE layer)</h2>",
+        "<table><tr><th>Layer/Expert</th>",
+    ]
+    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
+    html_parts.append("</tr>")
+
+    for name, counts in rows:
+        avg = counts.float().mean().item()
+        html_parts.append(f"<tr><td>{name}</td>")
+        for c in counts.tolist():
+            if avg > 0 and c < avg:
+                # Scale from white (at average) to full red (at zero)
+                ratio = c / avg
+                r_channel = 255
+                gb_channel = int(100 * ratio)
+                style = f' style="background: rgb({r_channel},{gb_channel},{gb_channel});"'
+            else:
+                style = ""
+            html_parts.append(f"<td{style}>{c}</td>")
+        html_parts.append("</tr>")
+
+    html_parts.append("</table>
") + html_content = "\n".join(html_parts) + + if output_dir is None: + output_dir = Path(".") + output_path = Path(output_dir) / "moe.html" + output_path.write_text(html_content) + print(f"Expert token count table saved to {output_path}") diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 447fc43a7..99987eeb3 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -76,6 +76,7 @@ QUANTIZATION_W4A8_NVFP4_FP8, ) from .model_utils import get_language_model_from_vl, is_multimodal_model +from .moe_utils import save_expert_token_count_table from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only from .quant_utils import ( fuse_prequant_layernorm, @@ -1003,6 +1004,8 @@ def export_hf_checkpoint( try: post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + save_expert_token_count_table(model, export_dir) + if hf_quant_config is not None: # Save hf_quant_config.json for backward compatibility with open(f"{export_dir}/hf_quant_config.json", "w") as file: diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 6aed8c439..423599398 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -450,10 +450,40 @@ class _QuantSparseMoe(QuantModule): """ def _setup(self): - pass + num_experts = 0 + if hasattr(self, "gate") and hasattr(self.gate, "num_experts"): + num_experts = self.gate.num_experts + elif hasattr(self, "num_experts"): + num_experts = self.num_experts + elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"): + num_experts = self.experts.num_experts + + self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu") + self._count_expert_tokens = False + + if hasattr(self, "gate"): + self.gate.register_forward_hook(self._gate_forward_hook) + + def _gate_forward_hook(self, module, input, output): + if not self._count_expert_tokens: + return + with torch.no_grad(): + if isinstance(output, tuple) and len(output) >= 3: + # v5.x TopKRouter: returns (logits, scores, indices) + indices = output[2] + else: + # v4.x nn.Linear gate: returns logits tensor + logits = output if not isinstance(output, tuple) else output[0] + top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k + _, indices = torch.topk(logits.float(), top_k, dim=-1) + counts = torch.bincount( + indices.reshape(-1).cpu(), minlength=len(self.expert_token_count) + ) + self.expert_token_count += counts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): + is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules()) + if is_calib: # If any of the experts are in calibration mode, we will forward all tokens to all experts # This is used only for calibration, we need to re-calculate the actual outputs again using # the original top_k @@ -475,7 +505,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: raise ValueError(f"Could not find num_experts in module {self}") super().forward(hidden_states) self.top_k = original_top_k - return super().forward(hidden_states) + # Enable counting only for the real-routing forward during calibration + self._count_expert_tokens = is_calib + output = super().forward(hidden_states) + self._count_expert_tokens = False + return output class 
_QuantLlama4TextExperts(QuantModule): From 7da77b960b32d387c3184bcfecdce3ec5167a359 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 18 Feb 2026 07:53:24 +0000 Subject: [PATCH 3/8] Update moe activation logics Signed-off-by: Chenjie Luo --- modelopt/torch/export/moe_utils.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index b040bed3a..87c2e1533 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -55,12 +55,10 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non avg = counts.float().mean().item() html_parts.append(f"{name}") for c in counts.tolist(): - if avg > 0 and c < avg: - # Scale from white (at average) to full red (at zero) - ratio = c / avg - r_channel = 255 - gb_channel = int(100 * ratio) - style = f' style="background: rgb({r_channel},{gb_channel},{gb_channel});"' + if avg > 0 and c < avg * 0.05: + style = ' style="background: #ff6666;"' + elif avg > 0 and c < avg * 0.1: + style = ' style="background: #ffcccc;"' else: style = "" html_parts.append(f"{c}") @@ -71,6 +69,6 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non if output_dir is None: output_dir = Path(".") - output_path = Path(output_dir) / "moe.html" + output_path = Path(output_dir) / ".moe.html" output_path.write_text(html_content) print(f"Expert token count table saved to {output_path}") From 2e29ee729ccb18da880699d0304689e352566f3d Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 18 Feb 2026 08:02:25 +0000 Subject: [PATCH 4/8] Add unittest Signed-off-by: Chenjie Luo --- .../quantization/plugins/test_sparse_moe.py | 317 ++++++++++++++++++ 1 file changed, 317 insertions(+) create mode 100644 tests/unit/torch/quantization/plugins/test_sparse_moe.py diff --git a/tests/unit/torch/quantization/plugins/test_sparse_moe.py b/tests/unit/torch/quantization/plugins/test_sparse_moe.py new file mode 100644 index 000000000..af4c9e8fe --- /dev/null +++ b/tests/unit/torch/quantization/plugins/test_sparse_moe.py @@ -0,0 +1,317 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for _is_sparse_moe_block and _QuantSparseMoe.""" + +import pytest +import torch +import torch.nn as nn + +pytest.importorskip("transformers") + +from _test_utils.torch.transformers_models import get_tiny_qwen3_moe + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.nn import QuantModuleRegistry +from modelopt.torch.quantization.plugins.huggingface import ( + TRANSFORMERS_VERSION_GE_5_0, + _is_sparse_moe_block, + register_sparse_moe_on_the_fly, +) + + +# --------------------------------------------------------------------------- +# Helpers: lightweight mock modules for _is_sparse_moe_block +# --------------------------------------------------------------------------- +class _FakeGateWithRouter(nn.Module): + """Mimics a v5.x TopKRouter gate with top_k and num_experts.""" + + def __init__(self, top_k=2, num_experts=4): + super().__init__() + self.top_k = top_k + self.num_experts = num_experts + self.linear = nn.Linear(8, num_experts) + + def forward(self, x): + return self.linear(x) + + +class _FakeExperts(nn.ModuleList): + def __init__(self, n=4): + super().__init__([nn.Linear(8, 8) for _ in range(n)]) + self.num_experts = n + + +class _MoEBlockWithGateRouter(nn.Module): + """Matches the primary detection path: gate.top_k + gate.num_experts.""" + + def __init__(self, num_experts=4, top_k=2): + super().__init__() + self.gate = _FakeGateWithRouter(top_k=top_k, num_experts=num_experts) + self.experts = _FakeExperts(num_experts) + + def forward(self, hidden_states): + logits = self.gate(hidden_states) + routing_weights, selected = torch.topk(logits, self.gate.top_k, dim=-1) + out = torch.zeros_like(hidden_states) + for i in range(self.gate.num_experts): + mask = (selected == i).any(dim=-1) + if mask.any(): + out[mask] += self.experts[i](hidden_states[mask]) + return out + + +class _MoEBlockFallback(nn.Module): + """Matches the fallback path: top_k + num_experts on the block itself.""" + + def __init__(self, num_experts=4, top_k=2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(8, num_experts) + self.experts = _FakeExperts(num_experts) + + def forward(self, hidden_states): + logits = self.gate(hidden_states) + routing_weights, selected = torch.topk(logits, self.top_k, dim=-1) + out = torch.zeros_like(hidden_states) + for i in range(self.num_experts): + mask = (selected == i).any(dim=-1) + if mask.any(): + out[mask] += self.experts[i](hidden_states[mask]) + return out + + +# --------------------------------------------------------------------------- +# Tests for _is_sparse_moe_block +# --------------------------------------------------------------------------- +class TestIsSparseBlock: + def test_no_experts_returns_false(self): + module = nn.Linear(8, 8) + assert _is_sparse_moe_block(module) is False + + def test_experts_but_no_gate_or_topk_returns_false(self): + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + assert _is_sparse_moe_block(module) is False + + def test_gate_with_router_attrs_returns_true(self): + block = _MoEBlockWithGateRouter(num_experts=4, top_k=2) + assert _is_sparse_moe_block(block) is True + + def test_fallback_block_level_attrs_returns_true(self): + block = _MoEBlockFallback(num_experts=4, top_k=2) + assert _is_sparse_moe_block(block) is True + + def test_gate_missing_num_experts_returns_false(self): + """gate.top_k present but gate.num_experts absent -> primary path fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = 
nn.Module() + gate.top_k = 2 + module.gate = gate + assert _is_sparse_moe_block(module) is False + + def test_gate_missing_top_k_returns_false(self): + """gate.num_experts present but gate.top_k absent -> primary path fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.num_experts = 4 + module.gate = gate + assert _is_sparse_moe_block(module) is False + + def test_block_level_only_top_k_returns_false(self): + """Only top_k on block (no num_experts) -> fallback fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + module.top_k = 2 + assert _is_sparse_moe_block(module) is False + + def test_block_level_only_num_experts_returns_false(self): + """Only num_experts on block (no top_k) -> fallback fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + module.num_experts = 4 + assert _is_sparse_moe_block(module) is False + + def test_glm4_like_block_rejected(self): + """A module with n_routed_experts instead of num_experts should be rejected.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.top_k = 2 + gate.n_routed_experts = 4 # different attr name + module.gate = gate + assert _is_sparse_moe_block(module) is False + + +# --------------------------------------------------------------------------- +# Tests for _QuantSparseMoe +# --------------------------------------------------------------------------- +class TestQuantSparseMoe: + """Tests for _QuantSparseMoe using a real tiny Qwen3Moe model.""" + + @staticmethod + def _get_moe_block(model): + """Return the first MoE block from the model.""" + for module in model.modules(): + if _is_sparse_moe_block(module): + return module + raise RuntimeError("No MoE block found in model") + + def test_register_sparse_moe_on_the_fly(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is not None: + pytest.skip("MoE type already registered (upstream change)") + + register_sparse_moe_on_the_fly(model) + assert QuantModuleRegistry.get(moe_type) is not None + + def test_setup_creates_expert_token_count(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + assert hasattr(converted, "expert_token_count") + expected_num_experts = moe_block.num_experts if hasattr(moe_block, "num_experts") else 0 + assert converted.expert_token_count.shape == (expected_num_experts,) + assert converted.expert_token_count.dtype == torch.long + assert (converted.expert_token_count == 0).all() + + def test_setup_count_expert_tokens_default_false(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + assert converted._count_expert_tokens is False + + def test_forward_no_calib_matches_original(self): + """When calibration is off, _QuantSparseMoe should produce the same output as the original.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + ref_block = 
self._get_moe_block(get_tiny_qwen3_moe()) + ref_block.load_state_dict(moe_block.state_dict()) + + converted = QuantModuleRegistry.convert(moe_block) + + torch.manual_seed(42) + x = torch.randn(1, 4, 32) + with torch.no_grad(): + out_ref = ref_block(x) + out_test = converted(x) + + if isinstance(out_ref, tuple): + out_ref = out_ref[0] + if isinstance(out_test, tuple): + out_test = out_test[0] + assert torch.allclose(out_ref, out_test, atol=1e-5) + + def test_forward_calib_sends_all_tokens_to_all_experts(self): + """During calibration, all experts should see tokens (expert_token_count all > 0).""" + model = get_tiny_qwen3_moe() + register_sparse_moe_on_the_fly(model) + + def calib_fn(model): + x = model.dummy_inputs["input_ids"] + model(x) + + mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_fn) + + for name, module in model.named_modules(): + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: + assert (module.expert_token_count > 0).all(), ( + f"Not all experts received tokens in {name}: {module.expert_token_count}" + ) + + def test_forward_calib_restores_top_k(self): + """After calibration forward, top_k should be restored to its original value.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + if TRANSFORMERS_VERSION_GE_5_0: + original_top_k = moe_block.gate.top_k + else: + original_top_k = moe_block.top_k + + converted = QuantModuleRegistry.convert(moe_block) + + # Simulate calibration mode: set _if_calib on a child TensorQuantizer + for m in converted.experts.modules(): + if hasattr(m, "_if_calib"): + m._if_calib = True + break + + x = torch.randn(1, 4, 32) + with torch.no_grad(): + converted(x) + + if TRANSFORMERS_VERSION_GE_5_0: + assert converted.gate.top_k == original_top_k + else: + assert converted.top_k == original_top_k + + def test_gate_forward_hook_counts_tokens(self): + """Verify the gate forward hook correctly counts expert token assignments.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + + # Reset counts and enable counting + converted.expert_token_count.zero_() + converted._count_expert_tokens = True + + hidden_size = converted.gate.in_features + x = torch.randn(8, hidden_size) + with torch.no_grad(): + converted.gate(x) + + # After one gate call with counting enabled, total assigned tokens should equal + # num_tokens * top_k + top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k + total_assigned = converted.expert_token_count.sum().item() + assert total_assigned == 8 * top_k + + # Disable counting and verify counts don't change + converted._count_expert_tokens = False + prev_counts = converted.expert_token_count.clone() + with torch.no_grad(): + converted.gate(x) + assert torch.equal(converted.expert_token_count, prev_counts) From 4b4ef6342dc2491ed96af729bd35cad232c4d550 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 18 Feb 2026 08:14:52 +0000 Subject: [PATCH 5/8] Update Signed-off-by: Chenjie Luo --- examples/llm_ptq/hf_ptq.py | 2 ++ modelopt/torch/export/__init__.py | 1 + modelopt/torch/export/unified_export_hf.py | 3 --- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 
de434e1cf..8bd052ed3 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -53,6 +53,7 @@ export_hf_checkpoint, export_tensorrt_llm_checkpoint, get_model_type, + save_expert_token_count_table, ) from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration @@ -727,6 +728,7 @@ def post_quantize( if args.verbose: mtq.print_quant_summary(full_model) + save_expert_token_count_table(full_model, args.export_path) # Run some samples torch.cuda.empty_cache() diff --git a/modelopt/torch/export/__init__.py b/modelopt/torch/export/__init__.py index 8b2ba56f4..5c0905ba3 100644 --- a/modelopt/torch/export/__init__.py +++ b/modelopt/torch/export/__init__.py @@ -19,6 +19,7 @@ from .model_config import * from .model_config_export import * from .model_utils import * +from .moe_utils import * from .plugins import * from .transformer_engine import * from .unified_export_hf import * diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 99987eeb3..447fc43a7 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -76,7 +76,6 @@ QUANTIZATION_W4A8_NVFP4_FP8, ) from .model_utils import get_language_model_from_vl, is_multimodal_model -from .moe_utils import save_expert_token_count_table from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only from .quant_utils import ( fuse_prequant_layernorm, @@ -1004,8 +1003,6 @@ def export_hf_checkpoint( try: post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) - save_expert_token_count_table(model, export_dir) - if hf_quant_config is not None: # Save hf_quant_config.json for backward compatibility with open(f"{export_dir}/hf_quant_config.json", "w") as file: From 9b9377a77d7c5f398979e8a2c098c18238916e36 Mon Sep 17 00:00:00 2001 From: Chenjie Luo Date: Wed, 18 Feb 2026 18:26:46 +0000 Subject: [PATCH 6/8] Fix Signed-off-by: Chenjie Luo --- examples/llm_ptq/hf_ptq.py | 8 ++++-- modelopt/torch/export/moe_utils.py | 7 +++-- modelopt/torch/quantization/model_quant.py | 26 ++++++++++++++----- .../torch/quantization/plugins/huggingface.py | 10 +++++-- .../quantization/plugins/test_sparse_moe.py | 21 ++++++++++----- 5 files changed, 53 insertions(+), 19 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 8bd052ed3..d7aadf994 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -727,8 +727,12 @@ def post_quantize( """ if args.verbose: - mtq.print_quant_summary(full_model) - save_expert_token_count_table(full_model, args.export_path) + try: + mtq.print_quant_summary(full_model, args.export_path) + save_expert_token_count_table(full_model, args.export_path) + except Exception as e: + print(f"Error saving quant summary: {e}") + print("Continuing with generation...") # Run some samples torch.cuda.empty_cache() diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py index 87c2e1533..a5ba465b1 100644 --- a/modelopt/torch/export/moe_utils.py +++ b/modelopt/torch/export/moe_utils.py @@ -39,6 +39,9 @@ def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | Non return num_experts = rows[0][1].shape[0] + assert all(r[1].shape[0] == num_experts for r in rows), ( + "All MoE layers must have the same number of experts" + ) html_parts = [ "