diff --git a/modelopt/torch/export/moe_utils.py b/modelopt/torch/export/moe_utils.py
new file mode 100644
index 000000000..04d736ebd
--- /dev/null
+++ b/modelopt/torch/export/moe_utils.py
@@ -0,0 +1,130 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for Mixture-of-Experts (MoE) model export."""
+
+from pathlib import Path
+
+import torch.nn as nn
+
+
+def sync_expert_amax_low_tokens(model: nn.Module, threshold: float = 0.05):
+    """Sync expert amax values across experts if the number of tokens routed to an expert is less than the threshold.
+
+    For each MoE layer, this function collects the maximum amax value across all experts for
+    each input quantizer, then overwrites the amax of experts whose token count falls below
+    ``threshold * mean_token_count`` with that maximum.
+    """
+    for module_name, module in model.named_modules():
+        if not (hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0):
+            continue
+
+        experts = list(module.experts.children())
+        num_experts = module.expert_token_count.shape[0]
+
+        expert_amax_max = {}
+        for expert in experts:
+            for quantizer_name, quantizer in expert.named_modules():
+                # We do not sync amax for AWQ.
+                if hasattr(quantizer, "pre_quant_scale") and quantizer.pre_quant_scale is not None:
+                    return
+
+                if (
+                    "input_quantizer" in quantizer_name
+                    and hasattr(quantizer, "_amax")
+                    and quantizer._amax is not None
+                    and quantizer._amax.numel() == 1
+                ):
+                    prev = expert_amax_max.get(quantizer_name)
+                    cur = quantizer._amax.detach().clone()
+                    if prev is None or cur > prev:
+                        expert_amax_max[quantizer_name] = cur
+
+        if not expert_amax_max:
+            continue
+
+        avg_token_count = module.expert_token_count.float().mean().item()
+        token_threshold = avg_token_count * threshold
+
+        print(f"[sync_expert_amax] {module_name}")
+        print(f"  token counts : {module.expert_token_count.tolist()}")
+        print(f"  avg={avg_token_count:.1f} threshold(<)={token_threshold:.1f}")
+        print(f"  tracked quantizers: {list(expert_amax_max.keys())}")
+
+        for i in range(num_experts):
+            token_count_i = module.expert_token_count[i].item()
+            if token_count_i < token_threshold:
+                expert_i = experts[i]
+                print(f"  expert {i}: token_count={token_count_i} - syncing amax")
+                for quantizer_name, quantizer in expert_i.named_modules():
+                    if quantizer_name in expert_amax_max:
+                        old_val = quantizer._amax.item() if quantizer._amax is not None else None
+                        quantizer._amax = expert_amax_max[quantizer_name].clone()
+                        print(f"    {quantizer_name}: {old_val} -> {quantizer._amax.item()}")
+
+
+def save_expert_token_count_table(
+    model: nn.Module, output_dir: str | Path | None = None, threshold: float = 0.05
+):
+    """Collect expert_token_count from all quantized MoE layers and save as an HTML table.
+
+    The table has rows for each MoE layer and columns for each expert, with cell values
+    showing the number of tokens routed to that expert during calibration.
+
+    Args:
+        model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
+        output_dir: Directory to save the HTML file. Defaults to the current directory.
+        threshold: Fraction of the average token count below which an expert's cell is
+            highlighted (the same threshold used for amax syncing). Defaults to 0.05.
+    """
+    rows = []
+    for name, module in model.named_modules():
+        if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
+            rows.append((name, module.expert_token_count))
+
+    if not rows:
+        return
+
+    num_experts = rows[0][1].shape[0]
+    html_parts = [
+        "<html><body>",
+        "<h2>Expert Token Counts (per MoE layer)</h2>",
+        "<table border=1><tr><th>Layer/Expert</th>",
+    ]
+    html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
+    html_parts.append("</tr>")
+
+    for name, counts in rows:
+        avg = counts.float().mean().item()
+        html_parts.append(f"<tr><td>{name}</td>")
+        for c in counts.tolist():
+            if avg > 0 and c < avg * threshold:
+                style = ' style="background: #ffcccc;"'
+            else:
+                style = ""
+            html_parts.append(f"<td{style}>{c}</td>")
+        html_parts.append("</tr>")
+
+    html_parts.append("</table></body></html>")
+    html_content = "\n".join(html_parts)
+
+    if output_dir is None:
+        output_dir = Path(".")
+    output_path = Path(output_dir) / "moe.html"
+    output_path.write_text(html_content)
+    print(f"Expert token count table saved to {output_path}")
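A minimal sketch of how the two new helpers behave in isolation. The `_Toy*` classes below are hypothetical stand-ins, not real modelopt modules; in an actual export, the `_QuantSparseMoe` plugin further down is what populates `expert_token_count` and the experts carry real quantizers:

```python
import torch
import torch.nn as nn

from modelopt.torch.export.moe_utils import (
    save_expert_token_count_table,
    sync_expert_amax_low_tokens,
)


class _ToyQuantizer(nn.Module):
    """Stand-in exposing only the attributes the helpers read."""

    def __init__(self, amax: float):
        super().__init__()
        self._amax = torch.tensor(amax)
        self.pre_quant_scale = None  # no AWQ, so amax syncing is not skipped


class _ToyExpert(nn.Module):
    def __init__(self, amax: float):
        super().__init__()
        self.input_quantizer = _ToyQuantizer(amax)


class _ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.ModuleList(_ToyExpert(a) for a in (1.0, 4.0, 2.0))
        # Expert 0 saw almost no calibration tokens.
        self.expert_token_count = torch.tensor([1, 500, 480], dtype=torch.long)


model = nn.Sequential(_ToyMoE())
save_expert_token_count_table(model, ".")  # writes ./moe.html with expert 0 highlighted
sync_expert_amax_low_tokens(model)  # expert 0's input_quantizer._amax becomes 4.0 (the max)
```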
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 447fc43a7..a5d948836 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -76,6 +76,7 @@
     QUANTIZATION_W4A8_NVFP4_FP8,
 )
 from .model_utils import get_language_model_from_vl, is_multimodal_model
+from .moe_utils import save_expert_token_count_table, sync_expert_amax_low_tokens
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
@@ -1003,6 +1004,9 @@ def export_hf_checkpoint(
     try:
         post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
 
+        save_expert_token_count_table(model, export_dir)
+        sync_expert_amax_low_tokens(model)
+
         if hf_quant_config is not None:
             # Save hf_quant_config.json for backward compatibility
             with open(f"{export_dir}/hf_quant_config.json", "w") as file:
diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py
index 807c92c2c..69f56eef2 100644
--- a/modelopt/torch/quantization/plugins/huggingface.py
+++ b/modelopt/torch/quantization/plugins/huggingface.py
@@ -450,20 +450,58 @@ class _QuantSparseMoe(QuantModule):
     """
 
     def _setup(self):
-        pass
+        num_experts = 0
+        if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
+            num_experts = self.gate.num_experts
+        elif hasattr(self, "num_experts"):
+            num_experts = self.num_experts
+        elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
+            num_experts = self.experts.num_experts
+
+        self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
+        self._count_expert_tokens = False
+
+        if hasattr(self, "gate"):
+            self.gate.register_forward_hook(self._gate_forward_hook)
+
+    def _gate_forward_hook(self, module, input, output):
+        if not self._count_expert_tokens:
+            return
+        with torch.no_grad():
+            if isinstance(output, tuple) and len(output) >= 3:
+                # v5.x TopKRouter: returns (logits, scores, indices)
+                indices = output[2]
+            else:
+                # v4.x nn.Linear gate: returns logits tensor
+                logits = output if not isinstance(output, tuple) else output[0]
+                top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
+                _, indices = torch.topk(logits.float(), top_k, dim=-1)
+            counts = torch.bincount(
+                indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
+            )
+            self.expert_token_count += counts
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
-            # If any of the experts are in calibration mode, we will forward all tokens to all experts
+        is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
+        has_prequant_scale = any(
+            getattr(m, "awq_lite", None) is not None or getattr(m, "awq_clip", None) is not None
+            for m in self.experts.modules()
+        )
+        self._count_expert_tokens = is_calib
+        if is_calib and has_prequant_scale:
+            # If any of the experts are in AWQ calibration mode, we will forward all tokens to all experts
             # This is used only for calibration, we need to re-calculate the actual outputs again using
             # the original top_k
             if TRANSFORMERS_VERSION_GE_5_0:
                 assert hasattr(self, "gate")
                 # Path for transformers >= 5.0
-                original_top_k = self.gate.topk
-                self.gate.topk = self.gate.num_experts
+                original_top_k = self.gate.top_k
+                self.gate.top_k = self.gate.num_experts
                 super().forward(hidden_states)
-                self.gate.topk = original_top_k
+                self.gate.top_k = original_top_k
             else:
                 # Path for transformers < 5.0
                 original_top_k = self.top_k
@@ -475,7 +513,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
                     raise ValueError(f"Could not find num_experts in module {self}")
                 super().forward(hidden_states)
                 self.top_k = original_top_k
-        return super().forward(hidden_states)
+        # Enable counting only for the real-routing forward during calibration
+        output = super().forward(hidden_states)
+        self._count_expert_tokens = False
+        return output
 
 
 class _QuantLlama4TextExperts(QuantModule):
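For reference, a tiny self-contained sketch (toy logits and shapes, not real router output) of the counting arithmetic `_gate_forward_hook` applies on the v4.x path; on the v5.x path the router's `indices` output is used directly:

```python
import torch

# Toy router logits for 4 tokens over 3 experts.
logits = torch.tensor(
    [
        [2.0, 1.0, 0.1],
        [0.1, 2.0, 1.0],
        [2.0, 0.1, 1.0],
        [1.0, 0.1, 2.0],
    ]
)
top_k = 2

# Recover the routed expert indices from the logits, as the hook does for an nn.Linear gate.
_, indices = torch.topk(logits.float(), top_k, dim=-1)  # shape (4, 2)

# Every (token, slot) routing decision counts one token for that expert.
counts = torch.bincount(indices.reshape(-1), minlength=logits.shape[-1])
print(counts)  # tensor([3, 2, 3]); sums to 4 * top_k
```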
@@ -765,10 +806,7 @@ def unpack_weight(self):
 
 
 try:
-    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe
-
-    if Llama4TextMoe not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
+    from transformers.models.llama4.modeling_llama4 import Llama4TextExperts
 
     if Llama4TextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -791,16 +829,6 @@ def unpack_weight(self):
 except ImportError:
     pass
 
-try:
-    from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
-
-    if MixtralSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
 try:
     from transformers.models.falcon.modeling_falcon import FalconLinear
 
@@ -809,36 +837,6 @@ def unpack_weight(self):
 except ImportError:
     pass
 
-try:
-    from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock
-
-    if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock
-
-    if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
-try:
-    from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock
-
-    if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
-            _QuantSparseMoe
-        )
-except ImportError:
-    pass
-
 try:
     from compressed_tensors.linear.compressed_linear import CompressedLinear
 
@@ -850,15 +848,7 @@ def unpack_weight(self):
     pass
 
 try:
-    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
-        Qwen3VLMoeTextExperts,
-        Qwen3VLMoeTextSparseMoeBlock,
-    )
-
-    if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
-        QuantModuleRegistry.register(
-            {Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
-        )(_QuantSparseMoe)
+    from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts
 
     if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
         QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
@@ -989,15 +979,55 @@ def register_falcon_linears_on_the_fly(model):
             QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)
 
 
-def register_minimax_m2_moe_on_the_fly(model):
-    """Register MiniMax M2 MoE modules as a QUANT_MODULE.
+def _is_sparse_moe_block(module):
+    """Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.
+
+    All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
+    share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
+    (``top_k`` and ``num_experts``), and an ``experts`` sub-module.
 
-    MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
+    This function detects that pattern instead of relying on class names, making it
+    forward-compatible with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have
+    ``gate`` and ``experts`` but use a different routing interface (``n_routed_experts`` instead
+    of ``num_experts``, custom ``route_tokens_to_experts``), so we require ``num_experts`` to be
+    present to avoid false positives.
     """
-    if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
-        moe_type = type(model.model.layers[0].block_sparse_moe)
-        if QuantModuleRegistry.get(moe_type) is None:
-            QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
+    if not hasattr(module, "experts"):
+        return False
+
+    # Primary: gate sub-module has top_k + num_experts (standard TopKRouter pattern)
+    if hasattr(module, "gate"):
+        gate = module.gate
+        has_topk = hasattr(gate, "top_k")
+        has_num_experts = hasattr(gate, "num_experts")
+        if has_topk and has_num_experts:
+            return True
+
+    # Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
+    return hasattr(module, "top_k") and hasattr(module, "num_experts")
+
+
+def register_sparse_moe_on_the_fly(model):
+    """Auto-detect and register MoE modules as _QuantSparseMoe.
+
+    Walks the model tree, identifies MoE blocks by their structural attributes
+    (``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
+    """
+    registered_types = set()
+    for name, module in model.named_modules():
+        mod_type = type(module)
+
+        # Avoid duplicate registration: skip if we already processed this type
+        # in this walk, or if it was previously registered in the QuantModuleRegistry.
+        if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
+            continue
+
+        if _is_sparse_moe_block(module):
+            print(
+                f"\033[1mDetected MoE module '{name}' of type {mod_type.__name__}, "
+                f"registering with _QuantSparseMoe.\033[0m"
+            )
+            QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
+            registered_types.add(mod_type)
 
 
 def _is_supported_hf_model(model):
@@ -1065,7 +1095,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
     [
         register_falcon_linears_on_the_fly,
         register_dbrx_moe_on_the_fly,
-        register_minimax_m2_moe_on_the_fly,
+        register_sparse_moe_on_the_fly,
         register_hf_attentions_on_the_fly,
         convert_hf_parallel_linears_on_the_fly,
     ]
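To make the structural detection concrete, a sketch with hypothetical toy classes (`ToyRouter`, `ToyMoEBlock`, and `ToyDenseBlock` are illustrative stand-ins, not real transformers modules):

```python
import torch.nn as nn

from modelopt.torch.quantization.plugins.huggingface import _is_sparse_moe_block


class ToyRouter(nn.Module):
    """Gate exposing the TopKRouter-style attributes the detector checks."""

    def __init__(self, num_experts: int, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        self.proj = nn.Linear(16, num_experts)


class ToyMoEBlock(nn.Module):
    """gate + experts: matches the structural pattern, so register_sparse_moe_on_the_fly
    would register this type and wrap it as _QuantSparseMoe."""

    def __init__(self, num_experts: int = 4, top_k: int = 2):
        super().__init__()
        self.gate = ToyRouter(num_experts, top_k)
        self.experts = nn.ModuleList(nn.Linear(16, 16) for _ in range(num_experts))


class ToyDenseBlock(nn.Module):
    """No `experts` sub-module, so the detector rejects it."""

    def __init__(self):
        super().__init__()
        self.mlp = nn.Linear(16, 16)


assert _is_sparse_moe_block(ToyMoEBlock())  # gate exposes top_k + num_experts
assert not _is_sparse_moe_block(ToyDenseBlock())  # no `experts` attribute
```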