diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d9a6ca893..fc1043c05 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -639,9 +639,6 @@ def export_quantized( extra_state_dict=mtp_state_dict, ) - # Copy custom model files (Python files and JSON configs) if trust_remote_code is used - copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) - # Restore default padding and export the tokenizer as well. if tokenizer is not None: tokenizer.padding_side = default_padding_side @@ -649,6 +646,12 @@ def export_quantized( tokenizer.pad_token = default_pad_token tokenizer.save_pretrained(export_path) + # Copy custom model files (Python files and JSON configs) if trust_remote_code is used. + # This must run AFTER tokenizer.save_pretrained() so original tokenizer files + # from the source checkpoint take precedence over regenerated ones (which may + # differ in format due to newer transformers versions). + copy_custom_model_files(args.pyt_ckpt_path, export_path, args.trust_remote_code) + end_time = time.time() print( f"Quantized model exported to: {export_path}. Total time used {end_time - start_time}s" diff --git a/modelopt/torch/export/layer_utils.py b/modelopt/torch/export/layer_utils.py index 9346e074b..2eea6c052 100755 --- a/modelopt/torch/export/layer_utils.py +++ b/modelopt/torch/export/layer_utils.py @@ -339,6 +339,7 @@ def is_moe(module: nn.Module) -> bool: "Qwen2MoeSparseMoeBlock".lower(), "Qwen3MoeSparseMoeBlock".lower(), "Qwen3NextSparseMoeBlock".lower(), + "Qwen3_5MoeSparseMoeBlock".lower(), ] ) @@ -999,6 +1000,7 @@ def module_match_name_list(module, name_list): "Qwen2MoeSparseMoeBlock", "Qwen3MoeSparseMoeBlock", "Qwen3NextSparseMoeBlock", + "Qwen3_5MoeSparseMoeBlock", "DeepseekMoE", ], ): @@ -1134,7 +1136,10 @@ def set_expert_quantizer_amax( # Apply target amax to quantizers that need it for module, attr_name, quantizer in all_quantizers: # Check if quantizer needs amax (use property for consistency) - needs_amax = getattr(quantizer, "amax", None) is None + # Also treat zero amax as needing recalibration — a zero amax is never valid + # and indicates the quantizer wasn't activated during calibration + amax = getattr(quantizer, "amax", None) + needs_amax = amax is None or (isinstance(amax, torch.Tensor) and torch.all(amax == 0)) # Skip dynamic quantizers for input quantizers if "input_quantizer" in attr_name and getattr(quantizer, "_dynamic", False): @@ -1740,7 +1745,7 @@ def _split_fused_qkv_weight_and_scaling( qkv_in = weight.shape[-1] if weight_dim > 1 else 1 - num_kv_heads = num_kv_heads if num_kv_heads else num_heads + num_kv_heads = num_kv_heads or num_heads assert num_heads % num_kv_heads == 0, ( f"num_heads({num_heads}) must be divisible by num_kv_heads({num_kv_heads}))." ) diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5703f4515..3235c5d2f 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -574,7 +574,7 @@ def _process_quantized_modules( """ fsdp_module_to_reshard = None - for _, sub_module in model.named_modules(): + for name, sub_module in model.named_modules(): # Optimization to perform resharding only once per decoder layer to avoid extra communication overhead if isinstance(sub_module, FSDPModule): # Every time we encounter a new FSDPModule, the previous decoder layer is fully processed. @@ -593,8 +593,13 @@ def _process_quantized_modules( sub_module.unpack_weight() if get_quantization_format(sub_module) != QUANTIZATION_NONE: if is_quantlinear(sub_module): - with fsdp2_aware_weight_update(model, sub_module, reshard=False): - _export_quantized_weight(sub_module, dtype) + try: + with fsdp2_aware_weight_update(model, sub_module, reshard=False): + _export_quantized_weight(sub_module, dtype) + except AssertionError as e: + raise AssertionError( + f"Failed to export module '{name}' (type={type(sub_module).__name__}): {e}" + ) from e elif ( "Llama4TextExperts" in type(sub_module).__name__ or "GptOssExperts" in type(sub_module).__name__ @@ -954,6 +959,45 @@ def _export_diffusers_checkpoint( print(f"Export complete. Saved to: {export_dir}") +def _revert_weight_conversion_noop(model: Any, state_dict: dict) -> dict: + """No-op replacement for transformers' revert_weight_conversion.""" + return state_dict + + +def _try_patch_module(mod_path: str) -> tuple[Any, Any] | None: + """Try to patch revert_weight_conversion in a single module.""" + import importlib + + try: + mod = importlib.import_module(mod_path) + if hasattr(mod, "revert_weight_conversion"): + original = getattr(mod, "revert_weight_conversion") + setattr(mod, "revert_weight_conversion", _revert_weight_conversion_noop) + return (mod, original) + except (ImportError, AttributeError): + pass + return None + + +def _patch_revert_weight_conversion() -> list[tuple[Any, Any]]: + """Patch revert_weight_conversion in transformers to avoid IndexError on scalar tensors.""" + patches: list[tuple[Any, Any]] = [] + for mod_path in [ + "transformers.core_model_loading", + "transformers.modeling_utils", + ]: + result = _try_patch_module(mod_path) + if result is not None: + patches.append(result) + return patches + + +def _unpatch_revert_weight_conversion(patches: list[tuple[Any, Any]]) -> None: + """Restore the original revert_weight_conversion functions.""" + for mod, original in patches: + mod.revert_weight_conversion = original + + def export_hf_checkpoint( model: Any, dtype: torch.dtype | None = None, @@ -1013,11 +1057,20 @@ def export_hf_checkpoint( model.hf_quantizer = None # Save model - model.save_pretrained( - export_dir, - state_dict={**post_state_dict, **(extra_state_dict or {})}, - save_modelopt_state=save_modelopt_state, - ) + # Temporarily disable revert_weight_conversion if available — it doesn't handle + # quantized state dicts (scalar scale tensors have 0 dimensions, causing IndexError). + # We must patch both the source module and the importing module since + # modeling_utils does `from core_model_loading import revert_weight_conversion`. + _patches = _patch_revert_weight_conversion() + + try: + model.save_pretrained( + export_dir, + state_dict={**post_state_dict, **(extra_state_dict or {})}, + save_modelopt_state=save_modelopt_state, + ) + finally: + _unpatch_revert_weight_conversion(_patches) original_config = f"{export_dir}/config.json" config_data = {} diff --git a/modelopt/torch/quantization/plugins/huggingface.py b/modelopt/torch/quantization/plugins/huggingface.py index a29d7c754..5b49c93fe 100644 --- a/modelopt/torch/quantization/plugins/huggingface.py +++ b/modelopt/torch/quantization/plugins/huggingface.py @@ -653,6 +653,107 @@ def forward( return next_states +class _Qwen35MoeExpertModule(nn.Module): + """Container for a single Qwen3.5 MoE expert's linear layers. + + Produces the naming pattern: experts.{id}.gate_proj.weight + (consistent with standard Qwen3 MoE per-expert module structure). + """ + + def __init__(self, hidden_dim: int, expert_dim: int): + super().__init__() + self.gate_proj = nn.Linear(hidden_dim, expert_dim, bias=False) + self.up_proj = nn.Linear(hidden_dim, expert_dim, bias=False) + self.down_proj = nn.Linear(expert_dim, hidden_dim, bias=False) + + +class _QuantQwen35MoeExperts(QuantModule): + def _setup(self): + """Modify the Qwen3_5MoeExperts by using per-expert nn.Module containers. + + This produces the naming pattern: experts.{id}.gate_proj.weight + (consistent with standard Qwen3 MoE). + """ + from accelerate import init_empty_weights + + dtype, device = self.gate_up_proj.dtype, self.gate_up_proj.device + + def _copy_weight(module, weight): + module.to_empty(device=device) + with torch.no_grad(): + module.weight.data = weight.detach().data.to(dtype=dtype, device=device) + + expert_dim = self.intermediate_dim + + with init_empty_weights(): + expert_modules = nn.ModuleList( + [ + _Qwen35MoeExpertModule(self.hidden_dim, expert_dim) + for _ in range(self.num_experts) + ] + ) + + for idx in range(self.num_experts): + # gate_up_proj shape: (num_experts, 2*intermediate_dim, hidden_dim) + # Already in (out_features, in_features) format, no transpose needed + _copy_weight(expert_modules[idx].gate_proj, self.gate_up_proj[idx, :expert_dim, :]) + _copy_weight(expert_modules[idx].up_proj, self.gate_up_proj[idx, expert_dim:, :]) + # down_proj shape: (num_experts, hidden_dim, intermediate_dim) + # Already in (out_features, in_features) format + _copy_weight(expert_modules[idx].down_proj, self.down_proj[idx]) + + delattr(self, "gate_up_proj") + delattr(self, "down_proj") + # Register expert modules directly as numbered children (like nn.ModuleList) + # so the naming pattern is: experts.{id}.gate_proj.weight (no extra nesting) + for idx in range(self.num_experts): + self.add_module(str(idx), expert_modules[idx]) + + def __len__(self): + """Support len() so the module is iterable like standard MoE experts.""" + return self.num_experts + + def __iter__(self): + """Support iteration over expert modules.""" + for idx in range(self.num_experts): + yield getattr(self, str(idx)) + + def __getitem__(self, idx): + """Support indexing to get individual expert modules.""" + return getattr(self, str(int(idx))) + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + final_hidden_states = torch.zeros_like(hidden_states) + with torch.no_grad(): + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() + for expert_idx in expert_hit: + expert_idx = expert_idx[0] + if expert_idx == self.num_experts: + continue + with torch.no_grad(): + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + expert = self[expert_idx] + gate = expert.gate_proj(current_state) + up = expert.up_proj(current_state) + current_hidden_states = self.act_fn(gate) * up + current_hidden_states = expert.down_proj(current_hidden_states) + current_hidden_states = ( + current_hidden_states * top_k_weights[token_idx, top_k_pos, None] + ) + final_hidden_states.index_add_( + 0, token_idx, current_hidden_states.to(final_hidden_states.dtype) + ) + return final_hidden_states + + class _QuantDbrxFFN(_QuantSparseMoe): @property def num_experts(self): @@ -797,6 +898,46 @@ def unpack_weight(self): pass +class _QuantQwen35MoeSparseMoeBlock(_QuantSparseMoe): + """Qwen3.5 MoE stores top_k/num_experts in the router (self.gate), not as direct attributes. + + We override forward instead of just bridging attributes because the router (self.gate) + uses its own top_k internally for routing decisions. We must modify self.gate.top_k + directly so all experts see calibration data. + """ + + def _setup(self): + self.num_experts = self.experts.num_experts + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + if any(getattr(m, "_if_calib", False) for m in self.experts.modules()): + # Force all tokens to all experts during calibration + original_top_k = self.gate.top_k + self.gate.top_k = self.num_experts + super(_QuantSparseMoe, self).forward(hidden_states) + self.gate.top_k = original_top_k + return super(_QuantSparseMoe, self).forward(hidden_states) + + +try: + from transformers.models.qwen3_5_moe.modeling_qwen3_5_moe import ( + Qwen3_5MoeExperts, + Qwen3_5MoeSparseMoeBlock, + ) + + if Qwen3_5MoeSparseMoeBlock not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen3_5MoeSparseMoeBlock: "hf.Qwen3_5MoeSparseMoeBlock"})( + _QuantQwen35MoeSparseMoeBlock + ) + + if Qwen3_5MoeExperts not in QuantModuleRegistry: + QuantModuleRegistry.register({Qwen3_5MoeExperts: "hf.Qwen3_5MoeExperts"})( + _QuantQwen35MoeExperts + ) +except ImportError: + pass + + class _QuantGptOssExperts(_QuantFunctionalMixin): """Quantized wrapper for `transformers.GptOssExperts`. diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 16bff49c2..7718b2126 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -224,7 +224,7 @@ def get_dataset_dataloader( A instance of dataloader. """ assert tokenizer is not None, "Please provide a tokenizer." - # batch_encode_plus will modify the tokenizer in place, so we need to clone it. + # Tokenizer encoding may modify the tokenizer in place, so we need to clone it. tokenizer = copy.deepcopy(tokenizer) if tokenizer.padding_side != "left": @@ -247,7 +247,7 @@ def get_dataset_dataloader( samples = get_dataset_samples(ds_name, num_sample) all_samples.extend(samples) - batch_encoded = tokenizer.batch_encode_plus( + batch_encoded = tokenizer( all_samples, return_tensors="pt", padding=True,