9 changes: 6 additions & 3 deletions examples/llm_ptq/hf_ptq.py
@@ -639,16 +639,19 @@ def export_quantized(
extra_state_dict=mtp_state_dict,
)

# Copy custom model files (Python files and JSON configs) if trust_remote_code is used
copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)

# Restore default padding and export the tokenizer as well.
if tokenizer is not None:
tokenizer.padding_side = default_padding_side
if default_pad_token is not None:
tokenizer.pad_token = default_pad_token
tokenizer.save_pretrained(export_path)

# Copy custom model files (Python files and JSON configs) if trust_remote_code is used.
# This must run AFTER tokenizer.save_pretrained() so original tokenizer files
# from the source checkpoint take precedence over regenerated ones (which may
# differ in format due to newer transformers versions).
copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code)

end_time = time.time()
print(
f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s"
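
The reordering relies on plain last-write-wins file semantics: whatever is copied into the export directory after tokenizer.save_pretrained() overwrites the freshly regenerated tokenizer files. A minimal sketch of that behaviour, using hypothetical paths and a simplified copy helper rather than the PR's copy_custom_model_files:

import shutil
from pathlib import Path

def copy_source_files_sketch(source_ckpt: str, export_path: str) -> None:
    # Copy JSON and Python files from the source checkpoint. Any file that
    # save_pretrained() already wrote (e.g. tokenizer_config.json) is
    # overwritten, so the original checkpoint's format takes precedence.
    for pattern in ("*.json", "*.py"):
        for f in Path(source_ckpt).glob(pattern):
            shutil.copy2(f, Path(export_path) / f.name)
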
9 changes: 7 additions & 2 deletions modelopt/torch/export/layer_utils.py
@@ -339,6 +339,7 @@ def is_moe(module: nn.Module) -> bool:
"Qwen2MoeSparseMoeBlock".lower(),
"Qwen3MoeSparseMoeBlock".lower(),
"Qwen3NextSparseMoeBlock".lower(),
"Qwen3_5MoeSparseMoeBlock".lower(),
]
)

@@ -999,6 +1000,7 @@ def module_match_name_list(module, name_list):
"Qwen2MoeSparseMoeBlock",
"Qwen3MoeSparseMoeBlock",
"Qwen3NextSparseMoeBlock",
"Qwen3_5MoeSparseMoeBlock",
"DeepseekMoE",
],
):
@@ -1134,7 +1136,10 @@ def set_expert_quantizer_amax(
# Apply target amax to quantizers that need it
for module, attr_name, quantizer in all_quantizers:
# Check if quantizer needs amax (use property for consistency)
needs_amax = getattr(quantizer, "amax", None) is None
# Also treat zero amax as needing recalibration — a zero amax is never valid
# and indicates the quantizer wasn't activated during calibration
amax = getattr(quantizer, "amax", None)
needs_amax = amax is None or (isinstance(amax, torch.Tensor) and torch.all(amax == 0))

# Skip dynamic quantizers for input quantizers
if "input_quantizer" in attr_name and getattr(quantizer, "_dynamic", False):
@@ -1740,7 +1745,7 @@ def _split_fused_qkv_weight_and_scaling(

qkv_in = weight.shape[-1] if weight_dim > 1 else 1

num_kv_heads = num_kv_heads if num_kv_heads else num_heads
num_kv_heads = num_kv_heads or num_heads
    assert num_heads % num_kv_heads == 0, (
        f"num_heads({num_heads}) must be divisible by num_kv_heads({num_kv_heads})."
    )
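
As a brief illustration of why a zero amax must be treated as uncalibrated: the exported quantization scale is derived from amax, so a zero value yields a zero scale and breaks dequantization. The sketch below assumes per-tensor FP8 (E4M3, maximum representable value 448.0) and is not code from this change:

import torch

def fp8_scale_from_amax(amax: torch.Tensor) -> torch.Tensor:
    # A quantizer that saw calibration data always records a non-zero magnitude,
    # so amax == 0 can only mean it was never activated; using it would produce
    # a zero scale and a divide-by-zero on dequantization.
    assert torch.all(amax > 0), "zero amax indicates an uncalibrated quantizer"
    return amax / 448.0
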
69 changes: 61 additions & 8 deletions modelopt/torch/export/unified_export_hf.py
@@ -574,7 +574,7 @@ def _process_quantized_modules(
"""
fsdp_module_to_reshard = None

for _, sub_module in model.named_modules():
for name, sub_module in model.named_modules():
# Optimization to perform resharding only once per decoder layer to avoid extra communication overhead
if isinstance(sub_module, FSDPModule):
# Every time we encounter a new FSDPModule, the previous decoder layer is fully processed.
@@ -593,8 +593,13 @@
sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:
if is_quantlinear(sub_module):
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_quantized_weight(sub_module, dtype)
try:
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_quantized_weight(sub_module, dtype)
except AssertionError as e:
raise AssertionError(
f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}"
) from e
elif (
"Llama4TextExperts" in type(sub_module).__name__
or "GptOssExperts" in type(sub_module).__name__
@@ -954,6 +959,45 @@ def _export_diffusers_checkpoint(
print(f"Export complete. Saved to: {export_dir}")


def _revert_weight_conversion_noop(model: Any, state_dict: dict) -> dict:
"""No-op replacement for transformers' revert_weight_conversion."""
return state_dict


def _try_patch_module(mod_path: str) -> tuple[Any, Any] | None:
"""Try to patch revert_weight_conversion in a single module."""
import importlib

try:
mod = importlib.import_module(mod_path)
if hasattr(mod, "revert_weight_conversion"):
original = getattr(mod, "revert_weight_conversion")
setattr(mod, "revert_weight_conversion", _revert_weight_conversion_noop)
return (mod, original)
except (ImportError, AttributeError):
pass
return None


def _patch_revert_weight_conversion() -> list[tuple[Any, Any]]:
"""Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors."""
patches: list[tuple[Any, Any]] = []
for mod_path in [
"transformers.core_model_loading",
"transformers.modeling_utils",
]:
result = _try_patch_module(mod_path)
if result is not None:
patches.append(result)
return patches


def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None:
"""Restore the original revert_weight_conversion functions."""
for mod, original in patches:
mod.revert_weight_conversion = original


def export_hf_checkpoint(
model: Any,
dtype: torch.dtype | None = None,
@@ -1013,11 +1057,20 @@ def export_hf_checkpoint(
model.hf_quantizer = None

# Save model
model.save_pretrained(
export_dir,
state_dict={**post_state_dict, **(extra_state_dict or {})},
save_modelopt_state=save_modelopt_state,
)
# Temporarily disable revert_weight_conversion if available — it doesn't handle
# quantized state dicts (scalar scale tensors have 0 dimensions, causing IndexError).
# We must patch both the source module and the importing module since
# modeling_utils does `from core_model_loading import revert_weight_conversion`.
_patches = _patch_revert_weight_conversion()

try:
model.save_pretrained(
export_dir,
state_dict={**post_state_dict, **(extra_state_dict or {})},
save_modelopt_state=save_modelopt_state,
)
finally:
_unpatch_revert_weight_conversion(_patches)

original_config = f"{export_dir}/config.json"
config_data = {}
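
The failure mode this patch works around is easy to reproduce in isolation: per-tensor scales in a quantized state dict are 0-dimensional tensors, and any helper that indexes them raises IndexError. Illustrative snippet, independent of transformers:

import torch

scale = torch.tensor(0.0123)  # scalar scale, as stored in quantized checkpoints
print(scale.ndim)             # 0

try:
    _ = scale[0]  # code that assumes at least one dimension fails here
except IndexError as err:
    print(f"IndexError: {err}")
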
141 changes: 141 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -653,6 +653,107 @@ def forward(
return next_states


class _Qwen35MoeExpertModule(nn.Module):
"""Container for a single Qwen3.5 MoE expert's linear layers.

Produces the naming pattern: experts.{id}.gate_proj.weight
(consistent with standard Qwen3 MoE per-expert module structure).
"""

def __init__(self, hidden_dim: int, expert_dim: int):
super().__init__()
self.gate_proj = nn.Linear(hidden_dim, expert_dim, bias=False)
self.up_proj = nn.Linear(hidden_dim, expert_dim, bias=False)
self.down_proj = nn.Linear(expert_dim, hidden_dim, bias=False)


class _QuantQwen35MoeExperts(QuantModule):
def _setup(self):
"""Modify the Qwen3_5MoeExperts by using per-expert nn.Module containers.

This produces the naming pattern: experts.{id}.gate_proj.weight
(consistent with standard Qwen3 MoE).
"""
from accelerate import init_empty_weights

dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device

def _copy_weight(module, weight):
module.to_empty(device=device)
with torch.no_grad():
module.weight.data = weight.detach().data.to(dtype=dtype, device=device)

expert_dim = self.intermediate_dim

with init_empty_weights():
expert_modules = nn.ModuleList(
[
_Qwen35MoeExpertModule(self.hidden_dim, expert_dim)
for _ in range(self.num_experts)
]
)

for idx in range(self.num_experts):
# gate_up_proj shape: (num_experts, 2*intermediate_dim, hidden_dim)
# Already in (out_features, in_features) format, no transpose needed
_copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :])
_copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :])
# down_proj shape: (num_experts, hidden_dim, intermediate_dim)
# Already in (out_features, in_features) format
_copy_weight(expert_modules[idx].down_proj, self.down_proj[idx])

delattr(self, "gate_up_proj")
delattr(self, "down_proj")
# Register expert modules directly as numbered children (like nn.ModuleList)
# so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting)
for idx in range(self.num_experts):
self.add_module(str(idx), expert_modules[idx])

def __len__(self):
"""Support len() so the module is iterable like standard MoE experts."""
return self.num_experts

def __iter__(self):
"""Support iteration over expert modules."""
for idx in range(self.num_experts):
yield getattr(self, str(idx))

def __getitem__(self, idx):
"""Support indexing to get individual expert modules."""
return getattr(self, str(int(idx)))

def forward(
self,
hidden_states: torch.Tensor,
top_k_index: torch.Tensor,
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == self.num_experts:
continue
with torch.no_grad():
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
expert = self[expert_idx]
gate = expert.gate_proj(current_state)
up = expert.up_proj(current_state)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = expert.down_proj(current_hidden_states)
current_hidden_states = (
current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
)
final_hidden_states.index_add_(
0, token_idx, current_hidden_states.to(final_hidden_states.dtype)
)
return final_hidden_states


class _QuantDbrxFFN(_QuantSparseMoe):
@property
def num_experts(self):
@@ -797,6 +898,46 @@ def unpack_weight(self):
pass


class _QuantQwen35MoeSparseMoeBlock(_QuantSparseMoe):
"""Qwen3.5 MoE stores top_k/num_experts in the router (self.gate), not as direct attributes.

We override forward instead of just bridging attributes because the router (self.gate)
uses its own top_k internally for routing decisions. We must modify self.gate.top_k
directly so all experts see calibration data.
"""

def _setup(self):
self.num_experts = self.experts.num_experts

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
if any(getattr(m, "_if_calib", False) for m in self.experts.modules()):
# Force all tokens to all experts during calibration
original_top_k = self.gate.top_k
self.gate.top_k = self.num_experts
super(_QuantSparseMoe, self).forward(hidden_states)
self.gate.top_k = original_top_k
return super(_QuantSparseMoe, self).forward(hidden_states)


try:
from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import (
Qwen3_5MoeExperts,
Qwen3_5MoeSparseMoeBlock,
)

if Qwen3_5MoeSparseMoeBlock not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3_5MoeSparseMoeBlock: "hf.Qwen3_5MoeSparseMoeBlock"})(
_QuantQwen35MoeSparseMoeBlock
)

if Qwen3_5MoeExperts not in QuantModuleRegistry:
QuantModuleRegistry.register({Qwen3_5MoeExperts: "hf.Qwen3_5MoeExperts"})(
_QuantQwen35MoeExperts
)
except ImportError:
pass


class _QuantGptOssExperts(_QuantFunctionalMixin):
"""Quantized wrapper for `transformers.GptOssExperts`.

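
A shape-level sketch of the fused-weight split performed in _QuantQwen35MoeExperts._setup, using dummy sizes and assuming the layouts stated in the comments above (gate_up_proj is (num_experts, 2 * intermediate_dim, hidden_dim) and already in (out_features, in_features) order):

import torch

num_experts, hidden_dim, expert_dim = 4, 8, 16
gate_up_proj = torch.randn(num_experts, 2 * expert_dim, hidden_dim)

# Per expert, the first expert_dim output rows become gate_proj.weight
# and the remaining rows become up_proj.weight, with no transpose needed.
gate_w = gate_up_proj[0, :expert_dim, :]
up_w = gate_up_proj[0, expert_dim:, :]
assert gate_w.shape == up_w.shape == (expert_dim, hidden_dim)
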
4 changes: 2 additions & 2 deletions modelopt/torch/utils/dataset_utils.py
@@ -224,7 +224,7 @@ def get_dataset_dataloader(
        An instance of dataloader.
"""
assert tokenizer is not None, "Please provide a tokenizer."
# batch_encode_plus will modify the tokenizer in place, so we need to clone it.
# Tokenizer encoding may modify the tokenizer in place, so we need to clone it.
tokenizer = copy.deepcopy(tokenizer)

if tokenizer.padding_side != "left":
@@ -247,7 +247,7 @@
samples = get_dataset_samples(ds_name, num_sample)
all_samples.extend(samples)

batch_encoded = tokenizer.batch_encode_plus(
batch_encoded = tokenizer(
all_samples,
return_tensors="pt",
padding=True,
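
Calling the tokenizer object directly is the API Hugging Face now recommends (batch_encode_plus is documented as deprecated in favour of __call__) and returns the same BatchEncoding. A quick usage check; the "gpt2" checkpoint is purely illustrative:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.pad_token = tok.eos_token  # gpt2 ships without a pad token

enc = tok(["hello world", "hi"], return_tensors="pt", padding=True)
print(enc["input_ids"].shape, enc["attention_mask"].shape)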