130 changes: 130 additions & 0 deletions modelopt/torch/export/moe_utils.py
@@ -0,0 +1,130 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for Mixture-of-Experts (MoE) model export."""

from pathlib import Path

import torch.nn as nn


def sync_expert_amax_low_tokens(model: nn.Module, threshold: float = 0.05):
"""Sync expert amax values across experts if the number of tokens routed to an expert is less than the threshold.

For each MoE layer, this function collects the maximum amax value across all experts for
each input quantizer, then overwrites the amax of experts whose token count falls below
``threshold * mean_token_count`` with that maximum.
"""
for module_name, module in model.named_modules():
if not (hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0):
continue

experts = list(module.experts.children())
num_experts = module.expert_token_count.shape[0]

expert_amax_max = {}
for expert in experts:
for quantizer_name, quantizer in expert.named_modules():
# We do not sync amax for AWQ.
if hasattr(quantizer, "pre_quant_scale") and quantizer.pre_quant_scale is not None:
return

if (
"input_quantizer" in quantizer_name
and hasattr(quantizer, "_amax")
and quantizer._amax is not None
and quantizer._amax.numel() == 1
):
prev = expert_amax_max.get(quantizer_name)
cur = quantizer._amax.detach().clone()
if prev is None or cur > prev:
expert_amax_max[quantizer_name] = cur

if not expert_amax_max:
continue

avg_token_count = module.expert_token_count.float().mean().item()
token_threshold = avg_token_count * threshold

print(f"[sync_expert_amax] {module_name}")
print(f" token counts : {module.expert_token_count.tolist()}")
print(f" avg={avg_token_count:.1f} threshold(<)={token_threshold:.1f}")
print(f" tracked quantizers: {list(expert_amax_max.keys())}")

for i in range(num_experts):
token_count_i = module.expert_token_count[i].item()
if token_count_i < token_threshold:
expert_i = experts[i]
print(f" expert {i}: token_count={token_count_i} — syncing amax")
for quantizer_name, quantizer in expert_i.named_modules():
if quantizer_name in expert_amax_max:
old_val = quantizer._amax.item() if quantizer._amax is not None else None
quantizer._amax = expert_amax_max[quantizer_name].clone()
print(f" {quantizer_name}: {old_val} -> {quantizer._amax.item()}")


def save_expert_token_count_table(
model: nn.Module, output_dir: str | Path | None = None, threshold: float = 0.05
):
"""Collect expert_token_count from all quantized MoE layers and save as an HTML table.

The table has rows for each MoE layer and columns for each expert, with cell values
showing the number of tokens routed to that expert during calibration.

Args:
model: The model containing quantized MoE layers with ``expert_token_count`` attributes.
output_dir: Directory to save the HTML file. Defaults to current directory.
threshold: Fraction of the mean token count below which a cell is highlighted as low. Defaults to 0.05.
"""
rows = []
for name, module in model.named_modules():
if hasattr(module, "expert_token_count") and module.expert_token_count.numel() > 0:
rows.append((name, module.expert_token_count))

if not rows:
return

num_experts = rows[0][1].shape[0]
html_parts = [
"<html><head><style>",
"table { border-collapse: collapse; font-family: monospace; }",
"th, td { border: 1px solid #ccc; padding: 4px 8px; text-align: right; }",
"th { background: #f0f0f0; }",
"</style></head><body>",
"<h2>Expert Token Counts (per MoE layer)</h2>",
"<table><tr><th>Layer/Expert</th>",
]
html_parts.extend(f"<th>{i}</th>" for i in range(num_experts))
html_parts.append("</tr>")

for name, counts in rows:
avg = counts.float().mean().item()
html_parts.append(f"<tr><td>{name}</td>")
for c in counts.tolist():
if avg > 0 and c < avg * threshold:
style = ' style="background: #ffcccc;"'
else:
style = ""
html_parts.append(f"<td{style}>{c}</td>")
html_parts.append("</tr>")

html_parts.append("</table></body></html>")
html_content = "\n".join(html_parts)

if output_dir is None:
output_dir = Path(".")
output_path = Path(output_dir) / "moe.html"
output_path.write_text(html_content)
print(f"Expert token count table saved to {output_path}")
4 changes: 4 additions & 0 deletions modelopt/torch/export/unified_export_hf.py
@@ -76,6 +76,7 @@
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
from .moe_utils import save_expert_token_count_table, sync_expert_amax_low_tokens
from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
from .quant_utils import (
fuse_prequant_layernorm,
@@ -1003,6 +1004,9 @@ def export_hf_checkpoint(
try:
post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)

save_expert_token_count_table(model, export_dir)
sync_expert_amax_low_tokens(model)

if hf_quant_config is not None:
# Save hf_quant_config.json for backward compatibility
with open(f"{export_dir}/hf_quant_config.json", "w") as file:
166 changes: 98 additions & 68 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -450,20 +450,58 @@ class _QuantSparseMoe(QuantModule):
"""

def _setup(self):
pass
num_experts = 0
if hasattr(self, "gate") and hasattr(self.gate, "num_experts"):
num_experts = self.gate.num_experts
elif hasattr(self, "num_experts"):
num_experts = self.num_experts
elif hasattr(self, "experts") and hasattr(self.experts, "num_experts"):
num_experts = self.experts.num_experts

self.expert_token_count = torch.zeros(num_experts, dtype=torch.long, device="cpu")
self._count_expert_tokens = False

if hasattr(self, "gate"):
self.gate.register_forward_hook(self._gate_forward_hook)

def _gate_forward_hook(self, module, input, output):
if not self._count_expert_tokens:
return
with torch.no_grad():
if isinstance(output, tuple) and len(output) >= 3:
# v5.x TopKRouter: returns (logits, scores, indices)
indices = output[2]
else:
# v4.x nn.Linear gate: returns logits tensor
logits = output if not isinstance(output, tuple) else output[0]
top_k = self.gate.top_k if hasattr(self.gate, "top_k") else self.top_k
_, indices = torch.topk(logits.float(), top_k, dim=-1)
counts = torch.bincount(
indices.reshape(-1).cpu(), minlength=len(self.expert_token_count)
)
self.expert_token_count += counts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
# If any of the experts are in calibration mode, we will forward all tokens to all experts
is_calib = any(getattr(m, "_if_calib", False) for m in self.experts.modules())
has_prequant_scale = any(
getattr(m, "awq_lite", None) is not None or getattr(m, "awq_clip", None) is not None
for m in self.experts.modules()
)
if is_calib and has_prequant_scale:
# If any expert is in AWQ calibration mode, forward all tokens to all experts.
# This pass is only for calibration; the actual outputs are re-computed below
# using the original top_k.
if TRANSFORMERS_VERSION_GE_5_0:
assert hasattr(self, "gate")
# Path for transformers >= 5.0
original_top_k = self.gate.topk
self.gate.topk = self.gate.num_experts
original_top_k = self.gate.top_k
self.gate.top_k = self.gate.num_experts
super().forward(hidden_states)
self.gate.topk = original_top_k
self.gate.top_k = original_top_k
else:
# Path for transformers < 5.0
original_top_k = self.top_k
@@ -475,7 +513,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
raise ValueError(f"Could not find num_experts in module {self}")
super().forward(hidden_states)
self.top_k = original_top_k
return super().forward(hidden_states)
# Enable counting only for the real-routing forward during calibration
self._count_expert_tokens = is_calib
output = super().forward(hidden_states)
self._count_expert_tokens = False
return output


class _QuantLlama4TextExperts(QuantModule):
@@ -765,10 +806,7 @@ def unpack_weight(self):


try:
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts, Llama4TextMoe

if Llama4TextMoe not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextMoe: "hf.Llama4TextMoe"})(_QuantSparseMoe)
from transformers.models.llama4.modeling_llama4 import Llama4TextExperts

if Llama4TextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Llama4TextExperts: "hf.Llama4TextExperts"})(
@@ -791,16 +829,6 @@ def unpack_weight(self):
except ImportError:
pass

try:
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock

if MixtralSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({MixtralSparseMoeBlock: "hf.MixtralSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from transformers.models.falcon.modeling_falcon import FalconLinear

@@ -809,36 +837,6 @@ def unpack_weight(self):
except ImportError:
pass

try:
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeSparseMoeBlock

if Qwen3MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3MoeSparseMoeBlock: "hf.Qwen3MoeSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock

if Qwen2MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen2MoeSparseMoeBlock: "hf.Qwen2MoeSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextSparseMoeBlock

if Qwen3NextSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3NextSparseMoeBlock: "hf.Qwen3NextSparseMoeBlock"})(
_QuantSparseMoe
)
except ImportError:
pass

try:
from compressed_tensors.linear.compressed_linear import CompressedLinear

@@ -850,15 +848,7 @@ def unpack_weight(self):
pass

try:
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import (
Qwen3VLMoeTextExperts,
Qwen3VLMoeTextSparseMoeBlock,
)

if Qwen3VLMoeTextSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register(
{Qwen3VLMoeTextSparseMoeBlock: "hf.Qwen3VLMoeTextSparseMoeBlock"}
)(_QuantSparseMoe)
from transformers.models.qwen3_vl_moe.modeling_qwen3_vl_moe import Qwen3VLMoeTextExperts

if Qwen3VLMoeTextExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3VLMoeTextExperts: "hf.Qwen3VLMoeTextExperts"})(
@@ -989,15 +979,55 @@ def register_falcon_linears_on_the_fly(model):
QuantModuleRegistry.register({linear_type: linear_type.__name__})(_QuantLinear)


def register_minimax_m2_moe_on_the_fly(model):
"""Register MiniMax M2 MoE modules as a QUANT_MODULE.
def _is_sparse_moe_block(module):
"""Check if a module is structurally a sparse MoE block compatible with _QuantSparseMoe.

All HuggingFace MoE blocks (Mixtral, Qwen3Moe, Qwen2Moe, Qwen3Next, Llama4, MiniMax, etc.)
share a common structural pattern: a ``gate`` (TopKRouter) sub-module with routing attributes
(``top_k`` and ``num_experts``), and an ``experts`` sub-module.

MiniMax M2 MoE modules are defined in the model card, so we need to register them on the fly.
This function detects that pattern instead of relying on class names, making it forward-compatible
with new MoE architectures. Some MoE models (e.g. Glm4MoeMoE) have ``gate`` and ``experts`` but
use a different routing interface (``n_routed_experts`` instead of ``num_experts``, custom
``route_tokens_to_experts``), so we require ``num_experts`` to be present to avoid false positives.
"""
if type(model).__name__ in ["MiniMaxM2ForCausalLM"]:
moe_type = type(model.model.layers[0].block_sparse_moe)
if QuantModuleRegistry.get(moe_type) is None:
QuantModuleRegistry.register({moe_type: moe_type.__name__})(_QuantSparseMoe)
if not hasattr(module, "experts"):
return False

# Primary: gate sub-module has top_k + num_experts (standard TopKRouter pattern)
if hasattr(module, "gate"):
gate = module.gate
has_topk = hasattr(gate, "top_k")
has_num_experts = hasattr(gate, "num_experts")
if has_topk and has_num_experts:
return True

# Fallback: top_k + num_experts on the block itself (older transformers, e.g. v4.x Qwen3Next)
return hasattr(module, "top_k") and hasattr(module, "num_experts")


def register_sparse_moe_on_the_fly(model):
"""Auto-detect and register MOE modules as _QuantSparseMoe.

Walks the model tree, identifies MoE blocks by their structural attributes
(``gate`` + ``experts``), and registers unregistered ones with ``_QuantSparseMoe``.
"""
registered_types = set()
for name, module in model.named_modules():
mod_type = type(module)

# Avoid duplicate registration: skip if we already processed this type
# in this walk, or if it was previously registered in the QuantModuleRegistry.
if mod_type in registered_types or QuantModuleRegistry.get(mod_type) is not None:
continue

if _is_sparse_moe_block(module):
print(
f"\033[1mDetected MoE module '{name}' of type {mod_type.__name__}, "
f"registering with _QuantSparseMoe.\033[0m"
)
QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantSparseMoe)
registered_types.add(mod_type)


def _is_supported_hf_model(model):
@@ -1065,7 +1095,7 @@ def _is_param_grad_enabled_for_auto_quantize(pname, model):
[
register_falcon_linears_on_the_fly,
register_dbrx_moe_on_the_fly,
register_minimax_m2_moe_on_the_fly,
register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
]