diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 046ae2434..744238656 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -6,6 +6,8 @@ NVIDIA Model Optimizer Changelog (Linux)
 
 **New Features**
 
+- MoE modules are now registered automatically in the PTQ workflow, so users no longer need to register them manually to ensure full expert calibration coverage.
+- ``hf_ptq.py`` now saves the quantization summary and the MoE expert token count table to the export directory.
 - Add sparse attention optimization for transformer models (``modelopt.torch.sparsity.attention_sparsity``). This reduces computational cost by skipping attention computation. Supports calibration for threshold selection on HuggingFace models. See `examples/llm_sparsity/attention_sparsity/README.md `_ for usage.
 
 0.42 (2026-02-xx)
diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index de434e1cf..d7aadf994 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -53,6 +53,7 @@
     export_hf_checkpoint,
     export_tensorrt_llm_checkpoint,
     get_model_type,
+    save_expert_token_count_table,
 )
 from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
 from modelopt.torch.quantization.config import _default_disabled_quantizer_cfg, need_calibration
@@ -726,7 +727,12 @@ def post_quantize(
     """
     if args.verbose:
-        mtq.print_quant_summary(full_model)
+        try:
+            mtq.print_quant_summary(full_model, args.export_path)
+            save_expert_token_count_table(full_model, args.export_path)
+        except Exception as e:
+            print(f"Error saving quantization summary or expert token count table: {e}")
+            print("Continuing with generation...")
 
     # Run some samples
     torch.cuda.empty_cache()
diff --git a/modelopt/torch/export/__init__.py b/modelopt/torch/export/__init__.py
index 8b2ba56f4..5c0905ba3 100644
--- a/modelopt/torch/export/__init__.py
+++ b/modelopt/torch/export/__init__.py
@@ -19,6 +19,7 @@
 from .model_config import *
 from .model_config_export import *
 from .model_utils import *
+from .moe_utils import *
 from .plugins import *
 from .transformer_engine import *
 from .unified_export_hf import *
diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py
new file mode 100644
index 000000000..a5ba465b1
--- /dev/null
+++ b/modelopt/torch/export/moe_utils.py
@@ -0,0 +1,77 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for Mixture-of-Experts (MoE) model export."""
+
+from pathlib import Path
+
+import torch.nn as nn
+
+
+def save_expert_token_count_table(model: nn.Module, output_dir: str | Path | None = None):
+    """Collect expert_token_count from all quantized MoE layers and save as an HTML table.
+
+    The table has rows for each MoE layer and columns for each expert, with cell values
+    showing the number of tokens routed to that expert during calibration.
+
+    Args:
+        model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
+        output_dir: Directory to save the HTML file. Defaults to the current directory.
+    """
+    rows = []
+    for name, module in model.named_modules():
+        if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
+            rows.append((name, module.expert_token_count))
+
+    if not rows:
+        return
+
+    num_experts = rows[0][1].shape[0]
+    assert all(r[1].shape[0] == num_experts for r in rows), (
+        "All MoE layers must have the same number of experts"
+    )
+    html_parts = [
+        "<html><body>",
+        "<h2>Expert Token Counts (per MoE layer)</h2>",
+        "<table border='1'><tr><th>Layer/Expert</th>",
+    ]
+    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
+    html_parts.append("</tr>")
+
+    for name, counts in rows:
+        avg = counts.float().mean().item()
+        html_parts.append(f"<tr><td>{name}</td>")
+        for c in counts.tolist():
+            # Highlight experts that received very few tokens relative to the per-layer average.
+            if avg > 0 and c < avg * 0.05:
+                style = ' style="background: #ff6666;"'
+            elif avg > 0 and c < avg * 0.1:
+                style = ' style="background: #ffcccc;"'
+            else:
+                style = ""
+            html_parts.append(f"<td{style}>{c}</td>")
+        html_parts.append("</tr>")
+
+    html_parts.append("</table></body></html>
") + html_content = "\n".join(html_parts) + + if output_dir is None: + output_dir = Path(".") + output_path = Path(output_dir) / ".moe.html" + output_path.write_text(html_content, encoding="utf-8") + print(f"\033[1mExpert token count table saved to {output_path}\033[0m") diff --git a/modelopt/torch/quantization/model_quant.py b/modelopt/torch/quantization/model_quant.py index a14469326..0b40de8ab 100644 --- a/modelopt/torch/quantization/model_quant.py +++ b/modelopt/torch/quantization/model_quant.py @@ -508,14 +508,26 @@ def enable_quantizer(model: nn.Module, wildcard_or_filter_func: str | Callable): @atomic_print -def print_quant_summary(model: nn.Module): +def print_quant_summary(model: nn.Module, output_dir: str | None = None): """Print summary of all quantizer modules in the model.""" - count = 0 - for name, mod in model.named_modules(): - if isinstance(mod, TensorQuantizer): - print(f"{name:80} {mod}") - count += 1 - print(f"{count} TensorQuantizers found in model") + lines = [ + f"{name:80} {mod}" + for name, mod in model.named_modules() + if isinstance(mod, TensorQuantizer) + ] + lines.append(f"{len(lines)} TensorQuantizers found in model") + + if output_dir: + path = ( + output_dir.joinpath(".quant_summary.txt") + if hasattr(output_dir, "joinpath") + else f"{output_dir}/.quant_summary.txt" + ) + with open(path, "w", encoding="utf-8") as f: + f.write("\n".join(lines) + "\n") + print(f"\033[1mQuant summary saved to {path}\033[0m") + else: + print("\n".join(lines)) def fold_weight(model: nn.Module): diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index 807c92c2c..aa274ea7e 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -450,20 +450,56 @@ class _QuantSparseMoe(QuantModule): """ def _setup(self): - pass + num_experts = 0 + if hasattr(self, "gate") and hasattr(self.gate, "num_experts"): + num_experts = self.gate.num_experts + elif hasattr(self, "num_experts"): + num_experts = self.num_experts + elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"): + num_experts = self.experts.num_experts + + self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu") + self._count_expert_tokens = False + + if num_experts == 0: + warnings.warn( + f"{self.__class__.__name__}: could not resolve num_experts; " + "expert routing will not be tracked for this layer." 
+ ) + return + + if hasattr(self, "gate"): + self.gate.register_forward_hook(self._gate_forward_hook) + + def _gate_forward_hook(self, module, input, output): + if not self._count_expert_tokens: + return + with torch.no_grad(): + if isinstance(output, tuple) and len(output) >= 3: + # v5.x TopKRouter: returns (logits, scores, indices) + indices = output[2] + else: + # v4.x nn.Linear gate: returns logits tensor + logits = output if not isinstance(output, tuple) else output[0] + top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k + _, indices = torch.topk(logits.float(), top_k, dim=-1) + counts = torch.bincount( + indices.reshape(-1).cpu(), minlength=len(self.expert_token_count) + ) + self.expert_token_count += counts def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): + is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules()) + if is_calib: # If any of the experts are in calibration mode, we will forward all tokens to all experts # This is used only for calibration, we need to re-calculate the actual outputs again using # the original top_k if TRANSFORMERS_VERSION_GE_5_0: - assert hasattr(self, "gate") - # Path for transformers >= 5.0 - original_top_k = self.gate.topk - self.gate.topk = self.gate.num_experts + assert hasattr(self, "gate") and hasattr(self.gate, "top_k") + original_top_k = self.gate.top_k + self.gate.top_k = self.gate.num_experts super().forward(hidden_states) - self.gate.topk = original_top_k + self.gate.top_k = original_top_k else: # Path for transformers < 5.0 original_top_k = self.top_k @@ -475,7 +511,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: raise ValueError(f"Could not find num_experts in module {self}") super().forward(hidden_states) self.top_k = original_top_k - return super().forward(hidden_states) + # Enable counting only for the real-routing forward during calibration + self._count_expert_tokens = is_calib + output = super().forward(hidden_states) + self._count_expert_tokens = False + return output class _QuantLlama4TextExperts(QuantModule): @@ -765,10 +805,7 @@ def unpack_weight(self): try: - from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe - - if Llama4TextMoe not in QuantModuleRegistry: - QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe) + from transformers.models.llama4.modeling_llama4 import Llama4TextExperts if Llama4TextExperts not in QuantModuleRegistry: QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})( @@ -791,16 +828,6 @@ def unpack_weight(self): except ImportError: pass -try: - from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock - - if MixtralSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - try: from transformers.models.falcon.modeling_falcon import FalconLinear @@ -809,36 +836,6 @@ def unpack_weight(self): except ImportError: pass -try: - from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock - - if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - -try: - from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock - - if Qwen2MoeSparseMoeBlock not in 
QuantModuleRegistry: - QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - -try: - from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock - - if Qwen3NextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})( - _QuantSparseMoe - ) -except ImportError: - pass - try: from compressed_tensors.linear.compressed_linear import CompressedLinear @@ -850,15 +847,7 @@ def unpack_weight(self): pass try: - from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import ( - Qwen3VLMoeTextExperts, - Qwen3VLMoeTextSparseMoeBlock, - ) - - if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry: - QuantModuleRegistry.register( - {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"} - )(_QuantSparseMoe) + from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts if Qwen3VLMoeTextExperts not in QuantModuleRegistry: QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})( @@ -989,15 +978,56 @@ def register_falcon_linears_on_the_fly(model): QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear) -def register_minimax_m2_moe_on_the_fly(model): - """Register MiniMax M2 MoE modules as a QUANT_MODULE. +def _is_sparse_moe_block(module): + """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe. + + All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.) + share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes + (``top_k`` and ``num_experts``), and an ``experts`` sub-module. - MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly. + This function detects that pattern instead of relying on class names, making it forward-compatible + with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but + use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom + ``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives. """ - if type(model).__name__ in ["MiniMaxM2ForCausalLM"]: - moe_type = type(model.model.layers[0].block_sparse_moe) - if QuantModuleRegistry.get(moe_type) is None: - QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe) + if not hasattr(module, "experts"): + return False + + # Primary: gate sub-module has topk/top_k + num_experts (standard TopKRouter pattern) + if hasattr(module, "gate"): + gate = module.gate + has_topk = hasattr(gate, "top_k") + has_num_experts = hasattr(gate, "num_experts") + if has_topk and has_num_experts: + return True + + # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next) + return hasattr(module, "top_k") and hasattr(module, "num_experts") + + +def register_sparse_moe_on_the_fly(model): + """Auto-detect and register MOE modules as _QuantSparseMoe. + + Walks the model tree, identifies MoE blocks by their structural attributes + (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``. + """ + visited_types = set() + for name, module in model.named_modules(): + mod_type = type(module) + + # Avoid duplicate registration: skip if we already processed this type + # in this walk, or if it was previously registered in the QuantModuleRegistry. 
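+        # For example, every decoder layer of a Mixtral- or Qwen-MoE-style model reuses the
+        # same MoE block class, so only the first occurrence needs to be inspected and registered.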
+ if mod_type in visited_types or QuantModuleRegistry.get(mod_type) is not None: + continue + + visited_types.add(mod_type) + + if _is_sparse_moe_block(module): + print( + f"\033[1mDetected MOE module '{name}' of type {mod_type.__name__}, " + f"registering with _QuantSparseMoe.\033[0m" + ) + QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe) def _is_supported_hf_model(model): @@ -1065,7 +1095,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model): [ register_falcon_linears_on_the_fly, register_dbrx_moe_on_the_fly, - register_minimax_m2_moe_on_the_fly, + register_sparse_moe_on_the_fly, register_hf_attentions_on_the_fly, convert_hf_parallel_linears_on_the_fly, ] diff --git a/tests/unit/torch/quantization/plugins/test_sparse_moe.py b/tests/unit/torch/quantization/plugins/test_sparse_moe.py new file mode 100644 index 000000000..6d548aa40 --- /dev/null +++ b/tests/unit/torch/quantization/plugins/test_sparse_moe.py @@ -0,0 +1,326 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for _is_sparse_moe_block and _QuantSparseMoe.""" + +import pytest +import torch +import torch.nn as nn + +pytest.importorskip("transformers") + +from _test_utils.torch.transformers_models import get_tiny_qwen3_moe + +import modelopt.torch.quantization as mtq +from modelopt.torch.quantization.nn import QuantModuleRegistry +from modelopt.torch.quantization.plugins.huggingface import ( + TRANSFORMERS_VERSION_GE_5_0, + _is_sparse_moe_block, + register_sparse_moe_on_the_fly, +) + + +# --------------------------------------------------------------------------- +# Helpers: lightweight mock modules for _is_sparse_moe_block +# --------------------------------------------------------------------------- +class _FakeGateWithRouter(nn.Module): + """Mimics a v5.x TopKRouter gate with top_k and num_experts.""" + + def __init__(self, top_k=2, num_experts=4): + super().__init__() + self.top_k = top_k + self.num_experts = num_experts + self.linear = nn.Linear(8, num_experts) + + def forward(self, x): + return self.linear(x) + + +class _FakeExperts(nn.ModuleList): + def __init__(self, n=4): + super().__init__([nn.Linear(8, 8) for _ in range(n)]) + self.num_experts = n + + +class _MoEBlockWithGateRouter(nn.Module): + """Matches the primary detection path: gate.top_k + gate.num_experts.""" + + def __init__(self, num_experts=4, top_k=2): + super().__init__() + self.gate = _FakeGateWithRouter(top_k=top_k, num_experts=num_experts) + self.experts = _FakeExperts(num_experts) + + def forward(self, hidden_states): + logits = self.gate(hidden_states) + routing_weights, selected = torch.topk(logits, self.gate.top_k, dim=-1) + out = torch.zeros_like(hidden_states) + for i in range(self.gate.num_experts): + mask = (selected == i).any(dim=-1) + if mask.any(): + out[mask] += self.experts[i](hidden_states[mask]) + return out + + +class _MoEBlockFallback(nn.Module): + 
"""Matches the fallback path: top_k + num_experts on the block itself.""" + + def __init__(self, num_experts=4, top_k=2): + super().__init__() + self.num_experts = num_experts + self.top_k = top_k + self.gate = nn.Linear(8, num_experts) + self.experts = _FakeExperts(num_experts) + + def forward(self, hidden_states): + logits = self.gate(hidden_states) + routing_weights, selected = torch.topk(logits, self.top_k, dim=-1) + out = torch.zeros_like(hidden_states) + for i in range(self.num_experts): + mask = (selected == i).any(dim=-1) + if mask.any(): + out[mask] += self.experts[i](hidden_states[mask]) + return out + + +# --------------------------------------------------------------------------- +# Tests for _is_sparse_moe_block +# --------------------------------------------------------------------------- +class TestIsSparseBlock: + def test_no_experts_returns_false(self): + module = nn.Linear(8, 8) + assert _is_sparse_moe_block(module) is False + + def test_experts_but_no_gate_or_topk_returns_false(self): + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + assert _is_sparse_moe_block(module) is False + + def test_gate_with_router_attrs_returns_true(self): + block = _MoEBlockWithGateRouter(num_experts=4, top_k=2) + assert _is_sparse_moe_block(block) is True + + def test_fallback_block_level_attrs_returns_true(self): + block = _MoEBlockFallback(num_experts=4, top_k=2) + assert _is_sparse_moe_block(block) is True + + def test_gate_missing_num_experts_returns_false(self): + """gate.top_k present but gate.num_experts absent -> primary path fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.top_k = 2 + module.gate = gate + assert _is_sparse_moe_block(module) is False + + def test_gate_missing_top_k_returns_false(self): + """gate.num_experts present but gate.top_k absent -> primary path fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.num_experts = 4 + module.gate = gate + assert _is_sparse_moe_block(module) is False + + def test_block_level_only_top_k_returns_false(self): + """Only top_k on block (no num_experts) -> fallback fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + module.top_k = 2 + assert _is_sparse_moe_block(module) is False + + def test_block_level_only_num_experts_returns_false(self): + """Only num_experts on block (no top_k) -> fallback fails.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + module.num_experts = 4 + assert _is_sparse_moe_block(module) is False + + def test_glm4_like_block_rejected(self): + """A module with n_routed_experts instead of num_experts should be rejected.""" + module = nn.Module() + module.experts = nn.ModuleList([nn.Linear(8, 8)]) + gate = nn.Module() + gate.top_k = 2 + gate.n_routed_experts = 4 # different attr name + module.gate = gate + assert _is_sparse_moe_block(module) is False + + +# --------------------------------------------------------------------------- +# Tests for _QuantSparseMoe +# --------------------------------------------------------------------------- +class TestQuantSparseMoe: + """Tests for _QuantSparseMoe using a real tiny Qwen3Moe model.""" + + @staticmethod + def _get_moe_block(model): + """Return the first MoE block from the model.""" + for module in model.modules(): + if _is_sparse_moe_block(module): + return module + raise RuntimeError("No MoE block found in model") + + def 
test_register_sparse_moe_on_the_fly(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is not None: + pytest.skip("MoE type already registered (upstream change)") + + register_sparse_moe_on_the_fly(model) + assert QuantModuleRegistry.get(moe_type) is not None + + def test_setup_creates_expert_token_count(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + assert hasattr(converted, "expert_token_count") + if hasattr(moe_block, "gate") and hasattr(moe_block.gate, "num_experts"): + expected_num_experts = moe_block.gate.num_experts + elif hasattr(moe_block, "num_experts"): + expected_num_experts = moe_block.num_experts + elif hasattr(moe_block, "experts") and hasattr(moe_block.experts, "num_experts"): + expected_num_experts = moe_block.experts.num_experts + else: + expected_num_experts = 0 + assert converted.expert_token_count.shape == (expected_num_experts,) + assert converted.expert_token_count.dtype == torch.long + assert (converted.expert_token_count == 0).all() + + def test_setup_count_expert_tokens_default_false(self): + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + assert converted._count_expert_tokens is False + + def test_forward_no_calib_matches_original(self): + """When calibration is off, _QuantSparseMoe should produce the same output as the original.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + ref_block = self._get_moe_block(get_tiny_qwen3_moe()) + ref_block.load_state_dict(moe_block.state_dict()) + + converted = QuantModuleRegistry.convert(moe_block) + + torch.manual_seed(42) + x = torch.randn(1, 4, 32) + with torch.no_grad(): + out_ref = ref_block(x) + out_test = converted(x) + + if isinstance(out_ref, tuple): + out_ref = out_ref[0] + if isinstance(out_test, tuple): + out_test = out_test[0] + assert torch.allclose(out_ref, out_test, atol=1e-5) + + def test_forward_calib_sends_all_tokens_to_all_experts(self): + """During calibration, all experts should see tokens (expert_token_count all > 0).""" + model = get_tiny_qwen3_moe() + register_sparse_moe_on_the_fly(model) + + def calib_fn(model): + x = model.dummy_inputs["input_ids"] + model(x) + + mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_fn) + + for name, module in model.named_modules(): + if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0: + assert (module.expert_token_count > 0).all(), ( + f"Not all experts received tokens in {name}: {module.expert_token_count}" + ) + + def test_forward_calib_restores_top_k(self): + """After calibration forward, top_k should be restored to its original value.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + if TRANSFORMERS_VERSION_GE_5_0: + original_top_k = moe_block.gate.top_k + else: + original_top_k = moe_block.top_k + + converted = QuantModuleRegistry.convert(moe_block) 
+ + # Simulate calibration mode: set _if_calib on a child TensorQuantizer + for m in converted.experts.modules(): + if hasattr(m, "_if_calib"): + m._if_calib = True + break + + x = torch.randn(1, 4, 32) + with torch.no_grad(): + converted(x) + + if TRANSFORMERS_VERSION_GE_5_0: + assert converted.gate.top_k == original_top_k + else: + assert converted.top_k == original_top_k + + def test_gate_forward_hook_counts_tokens(self): + """Verify the gate forward hook correctly counts expert token assignments.""" + model = get_tiny_qwen3_moe() + moe_block = self._get_moe_block(model) + moe_type = type(moe_block) + + if QuantModuleRegistry.get(moe_type) is None: + register_sparse_moe_on_the_fly(model) + + converted = QuantModuleRegistry.convert(moe_block) + + # Reset counts and enable counting + converted.expert_token_count.zero_() + converted._count_expert_tokens = True + + if TRANSFORMERS_VERSION_GE_5_0: + hidden_size = converted.gate.weight.shape[1] + top_k = converted.gate.top_k + else: + hidden_size = converted.gate.in_features + top_k = converted.top_k if hasattr(converted, "top_k") else converted.gate.top_k + + x = torch.randn(8, hidden_size) + with torch.no_grad(): + converted.gate(x) + total_assigned = converted.expert_token_count.sum().item() + assert total_assigned == 8 * top_k + + # Disable counting and verify counts don't change + converted._count_expert_tokens = False + prev_counts = converted.expert_token_count.clone() + with torch.no_grad(): + converted.gate(x) + assert torch.equal(converted.expert_token_count, prev_counts)
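
Example of how the pieces above fit together in a user script. This is a minimal sketch rather than part of the diff: the checkpoint id and export directory are placeholders, and it assumes a HuggingFace MoE model small enough to calibrate on a single device.

    import os

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    import modelopt.torch.quantization as mtq
    from modelopt.torch.export import save_expert_token_count_table

    model_id = "<your-hf-moe-checkpoint>"  # placeholder, e.g. a Mixtral- or Qwen-MoE-style model
    model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    def calib_forward(m):
        # Any representative calibration text works; MoE blocks are detected and registered
        # automatically, and per-expert token counts are collected on the real top-k routing pass.
        batch = tokenizer("The quick brown fox jumps over the lazy dog.", return_tensors="pt")
        with torch.no_grad():
            m(**batch)

    mtq.quantize(model, mtq.INT8_DEFAULT_CFG, calib_forward)

    export_dir = "export"
    os.makedirs(export_dir, exist_ok=True)
    mtq.print_quant_summary(model, export_dir)        # writes <export_dir>/.quant_summary.txt
    save_expert_token_count_table(model, export_dir)  # writes <export_dir>/.moe.html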