diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py
index d5bab9b4e..007efe1b0 100644
--- a/modelopt/torch/export/plugins/mcore_common.py
+++ b/modelopt/torch/export/plugins/mcore_common.py
@@ -52,6 +52,8 @@
     "LlamaForCausalLMEagle3Deep": eagle3_deep_llama_causal_lm_export,
     "Qwen3ForCausalLM": qwen3_causal_lm_export,
     "Qwen3MoeForCausalLM": qwen3_causal_lm_export,
+    "Qwen3VLForConditionalGeneration": qwen3_causal_lm_export,
+    "Qwen3VLMoeForConditionalGeneration": qwen3_causal_lm_export,
     "Qwen2ForCausalLM": qwen25_causal_lm_export,
     "GptOssForCausalLM": gptoss_causal_lm_export,
 }
diff --git a/modelopt/torch/export/unified_export_megatron.py b/modelopt/torch/export/unified_export_megatron.py
index 8a6d76b34..39dbab3ad 100644
--- a/modelopt/torch/export/unified_export_megatron.py
+++ b/modelopt/torch/export/unified_export_megatron.py
@@ -79,12 +79,97 @@
     has_mcore = True
 
+Qwen3VLModel = None
+try:
+    from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel
+except ImportError:
+    pass
+
 __all__ = [
     "export_mcore_gpt_to_hf",
     "import_mcore_gpt_from_hf",
 ]
 
 
+class _FusedLayerNormProxy(torch.nn.Module):
+    """Proxy module exposing fused layernorm weights from TELayerNormColumnParallelLinear.
+
+    When using TE spec, the input layernorm and pre-MLP layernorm are fused into the
+    subsequent linear layer (TELayerNormColumnParallelLinear). The layernorm weight is
+    stored as ``layer_norm_weight`` on the fused linear module rather than as a separate
+    ``weight`` on a standalone layernorm module.
+
+    This proxy wraps that fused weight so the existing export rules (which expect a
+    module with a ``.weight`` attribute) can export it with the correct HF key name.
+    """
+
+    def __init__(self, fused_linear: torch.nn.Module):
+        super().__init__()
+        self.weight = fused_linear.layer_norm_weight
+        bias = getattr(fused_linear, "layer_norm_bias", None)
+        if bias is not None:
+            self.bias = bias
+
+
+class _MoEExpertConfigProxy:
+    """Proxy that presents ``moe_ffn_hidden_size`` as ``ffn_hidden_size``.
+
+    ``SequentialMLP`` deep-copies the ``TransformerConfig`` and overrides
+    ``ffn_hidden_size = moe_ffn_hidden_size`` so that each expert MLP (and the
+    rule function ``_gated_mlp_slicing``) sees the correct value via
+    ``module.config.ffn_hidden_size``.
+
+    ``TEGroupedMLP`` does **not** perform that override, so its
+    ``config.ffn_hidden_size`` still holds the dense-MLP size. This proxy
+    bridges the gap by returning ``moe_ffn_hidden_size`` when
+    ``ffn_hidden_size`` is accessed, and delegates everything else to the
+    original config.
+    """
+
+    def __init__(self, config):
+        object.__setattr__(self, "_config", config)
+        object.__setattr__(
+            self,
+            "ffn_hidden_size",
+            getattr(config, "moe_ffn_hidden_size", config.ffn_hidden_size),
+        )
+
+    def __getattr__(self, name):
+        return getattr(self._config, name)
+
+
+class _GroupedLinearExpertProxy:
+    """Present a single expert's weight slice from a TE GroupedLinear module.
+
+    TE ``GroupedLinear`` stores all expert weights as ``weight0``, ``weight1``,
+    …, ``weight{n-1}`` and shares a single ``weight_quantizer`` /
+    ``input_quantizer`` across experts. The existing export rule functions
+    (``_name_remapping``, ``_gated_mlp_slicing``, …) expect a module with a
+    single ``.weight`` attribute and per-module quantizers.
+
+    This lightweight proxy satisfies that contract for a single expert by:
+
+    * Exposing ``weight{expert_id}`` as ``.weight``
+    * Attaching a ``config`` proxy (``_MoEExpertConfigProxy``) so that
+      ``module.config.ffn_hidden_size`` returns ``moe_ffn_hidden_size``
+    * Delegating every other attribute (``weight_quantizer``,
+      ``input_quantizer``, …) to the underlying ``GroupedLinear``.
+    """
+
+    def __init__(self, grouped_linear, expert_id, config):
+        object.__setattr__(self, "_grouped_linear", grouped_linear)
+        object.__setattr__(self, "_expert_id", expert_id)
+        object.__setattr__(self, "config", _MoEExpertConfigProxy(config))
+        # Expose the individual expert weight as .weight
+        object.__setattr__(
+            self, "weight", getattr(grouped_linear, f"weight{expert_id}")
+        )
+
+    def __getattr__(self, name):
+        # Delegate quantizer attrs, bias, etc. to the GroupedLinear module
+        return getattr(self._grouped_linear, name)
+
+
 class GPTModelExporter:
     """Megatron Core GPTModel Exporter.
@@ -115,7 +200,10 @@ def __init__(
         moe_router_dtype: str | None = None,
     ):
         """Create a GPTModel exporter instance."""
-        if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
+        _supported_types = (GPTModel, MambaModel, LLaVAModel)
+        if Qwen3VLModel is not None:
+            _supported_types = _supported_types + (Qwen3VLModel,)
+        if not isinstance(model, _supported_types):
             raise ValueError("Input to GPTModelExport must be a megatron.core.models.GPTModel!")
 
         self._state_dict = OrderedDict()
@@ -139,13 +227,14 @@ def __init__(
         self._hf_text_config.head_dim = model.config.kv_channels
         self._hf_text_config.num_attention_heads = model.config.num_attention_heads
         self._hf_text_config.num_key_value_heads = model.config.num_query_groups
-        self.is_multimodal = isinstance(model, LLaVAModel)
+        self.is_multimodal = isinstance(model, LLaVAModel) or (
+            Qwen3VLModel is not None and isinstance(model, Qwen3VLModel)
+        )
         if not self.is_multimodal:
             self._hf_text_config.intermediate_size = model.config.ffn_hidden_size
         self._hf_quant_config: dict = {}
         self._hf_extra_config = None
         self.export_extra_modules = export_extra_modules
-        self.is_multimodal = isinstance(model, LLaVAModel)
         self.model = model.language_model if self.is_multimodal else model
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
@@ -489,6 +578,17 @@ def _get_state_dict(self):
     def _get_transformer_layer_state_dict(self, layer, layer_id):
         if not isinstance(layer.input_layernorm, IdentityOp):
             self.rules["input_layernorm"](layer.input_layernorm, layer_id)
+        elif (
+            "input_layernorm" in self.rules
+            and hasattr(layer, "self_attention")
+            and not isinstance(layer.self_attention, IdentityOp)
+            and hasattr(layer.self_attention, "linear_qkv")
+            and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
+        ):
+            # TE spec: input layernorm is fused into TELayerNormColumnParallelLinear
+            self.rules["input_layernorm"](
+                _FusedLayerNormProxy(layer.self_attention.linear_qkv), layer_id
+            )
 
         if not isinstance(layer.self_attention, IdentityOp):
             if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -527,6 +627,14 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
 
         if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
             self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id)
+        elif (
+            "pre_mlp_layernorm" in self.rules
+            and not isinstance(layer.mlp, IdentityOp)
+            and hasattr(layer.mlp, "linear_fc1")
+            and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
+        ):
+            # TE spec: pre-MLP layernorm is fused into TELayerNormColumnParallelLinear
self.rules["pre_mlp_layernorm"](_FusedLayerNormProxy(layer.mlp.linear_fc1), layer_id) if not isinstance(layer.mlp, IdentityOp): if "MoE" in str(type(layer.mlp)): @@ -542,22 +650,44 @@ def _get_transformer_layer_state_dict(self, layer, layer_id): self.rules["shared_experts.linear_fc2"]( layer.mlp.shared_experts.linear_fc2, layer_id ) - if not self.rules.get("use_packed_local_experts", False): - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + if hasattr(layer.mlp.experts, "local_experts"): + # SequentialMLP: each expert is an individual MLP module + if not self.rules.get("use_packed_local_experts", False): + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["local_experts.linear_fc1"]( + expert.linear_fc1, layer_id, expert_id + ) + self.rules["local_experts.linear_fc2"]( + expert.linear_fc2, layer_id, expert_id + ) + else: + # For llama 4, in hf unified checkpoint, all local experts share one scale self.rules["local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id + layer.mlp.experts.local_experts, layer_id ) self.rules["local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id + layer.mlp.experts.local_experts, layer_id ) else: - # For llama 4, in hf unified checkpoint, all local experts share one scale - self.rules["local_experts.linear_fc1"]( - layer.mlp.experts.local_experts, layer_id - ) - self.rules["local_experts.linear_fc2"]( - layer.mlp.experts.local_experts, layer_id - ) + # GroupedMLP / TEGroupedMLP: experts are fused into + # linear_fc1 and linear_fc2 (TE GroupedLinear) with + # per-expert weights stored as weight0, weight1, ... + experts_module = layer.mlp.experts + num_experts = experts_module.num_local_experts + expert_config = experts_module.config + for expert_id in range(num_experts): + fc1_proxy = _GroupedLinearExpertProxy( + experts_module.linear_fc1, expert_id, expert_config + ) + fc2_proxy = _GroupedLinearExpertProxy( + experts_module.linear_fc2, expert_id, expert_config + ) + self.rules["local_experts.linear_fc1"]( + fc1_proxy, layer_id, expert_id + ) + self.rules["local_experts.linear_fc2"]( + fc2_proxy, layer_id, expert_id + ) else: self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id) self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id) @@ -598,6 +728,14 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]: def _get_mamba_layer_state_dict(self, layer, layer_id): if not isinstance(layer.norm, IdentityOp): self.rules["norm"](layer.norm, layer_id) + elif ( + "norm" in self.rules + and hasattr(layer, "mixer") + and hasattr(layer.mixer, "in_proj") + and hasattr(layer.mixer.in_proj, "layer_norm_weight") + ): + # TE spec: norm is fused into TELayerNormColumnParallelLinear (in_proj) + self.rules["norm"](_FusedLayerNormProxy(layer.mixer.in_proj), layer_id) self.rules["mixer_norm"](layer.mixer.norm, layer_id) self.rules["A_log"](layer.mixer.A_log, layer_id) @@ -695,13 +833,31 @@ def _get_eagle_module_state_dict(self): self.rules["eagle_module.shared_experts.linear_fc2"]( layer.mlp.shared_experts.linear_fc2, layer_id ) - for expert_id, expert in enumerate(layer.mlp.experts.local_experts): - self.rules["eagle_module.local_experts.linear_fc1"]( - expert.linear_fc1, layer_id, expert_id - ) - self.rules["eagle_module.local_experts.linear_fc2"]( - expert.linear_fc2, layer_id, expert_id - ) + if hasattr(layer.mlp.experts, "local_experts"): + for expert_id, expert in enumerate(layer.mlp.experts.local_experts): + self.rules["eagle_module.local_experts.linear_fc1"]( + expert.linear_fc1, 
+                            )
+                            self.rules["eagle_module.local_experts.linear_fc2"](
+                                expert.linear_fc2, layer_id, expert_id
+                            )
+                    else:
+                        experts_module = layer.mlp.experts
+                        num_experts = experts_module.num_local_experts
+                        expert_config = experts_module.config
+                        for expert_id in range(num_experts):
+                            fc1_proxy = _GroupedLinearExpertProxy(
+                                experts_module.linear_fc1, expert_id, expert_config
+                            )
+                            fc2_proxy = _GroupedLinearExpertProxy(
+                                experts_module.linear_fc2, expert_id, expert_config
+                            )
+                            self.rules["eagle_module.local_experts.linear_fc1"](
+                                fc1_proxy, layer_id, expert_id
+                            )
+                            self.rules["eagle_module.local_experts.linear_fc2"](
+                                fc2_proxy, layer_id, expert_id
+                            )
                 else:
                     self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
                     self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)
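Reviewer note (not part of the diff): the delegation trick that _GroupedLinearExpertProxy relies on is easiest to see in isolation. The sketch below is a minimal, self-contained illustration in plain PyTorch; FakeGroupedLinear and ExpertView are made-up stand-ins (the real TE GroupedLinear carries more state), showing how one expert of a fused module can be exposed as an ordinary module with a single .weight while every other attribute lookup falls through to the shared parent.

import torch


class FakeGroupedLinear(torch.nn.Module):
    """Stand-in for TE GroupedLinear: per-expert weights stored as weight0, weight1, ..."""

    def __init__(self, num_experts: int, in_features: int, out_features: int):
        super().__init__()
        self.num_local_experts = num_experts
        for i in range(num_experts):
            setattr(self, f"weight{i}", torch.nn.Parameter(torch.randn(out_features, in_features)))
        # Shared across all experts, like weight_quantizer / input_quantizer on the real module.
        self.weight_quantizer = object()


class ExpertView:
    """Same delegation idea as _GroupedLinearExpertProxy (config proxy omitted)."""

    def __init__(self, grouped_linear, expert_id):
        object.__setattr__(self, "_grouped_linear", grouped_linear)
        # Expose this expert's slice under the name the export rules expect.
        object.__setattr__(self, "weight", getattr(grouped_linear, f"weight{expert_id}"))

    def __getattr__(self, name):
        # Only reached when normal lookup fails: everything else comes from the parent module.
        return getattr(self._grouped_linear, name)


grouped = FakeGroupedLinear(num_experts=4, in_features=8, out_features=16)
view = ExpertView(grouped, expert_id=2)
assert view.weight is grouped.weight2                      # per-expert weight exposed as .weight
assert view.weight_quantizer is grouped.weight_quantizer   # shared state reached via delegation

Because __getattr__ is only consulted when normal attribute lookup fails, the per-expert .weight set in __init__ takes precedence while the shared quantizer objects are still reached through the parent, which is the contract the existing rule functions (_name_remapping, _gated_mlp_slicing) expect.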