2 changes: 2 additions & 0 deletions modelopt/torch/export/plugins/mcore_common.py
@@ -52,6 +52,8 @@
"LlamaForCausalLMEagle3Deep": eagle3_deep_llama_causal_lm_export,
"Qwen3ForCausalLM": qwen3_causal_lm_export,
"Qwen3MoeForCausalLM": qwen3_causal_lm_export,
"Qwen3VLForConditionalGeneration": qwen3_causal_lm_export,
"Qwen3VLMoeForConditionalGeneration": qwen3_causal_lm_export,
"Qwen2ForCausalLM": qwen25_causal_lm_export,
"GptOssForCausalLM": gptoss_causal_lm_export,
}
198 changes: 177 additions & 21 deletions modelopt/torch/export/unified_export_megatron.py
@@ -79,12 +79,97 @@

has_mcore = True

Qwen3VLModel = None
try:
Contributor: does Qwen3VLModel exist in mcore? the other imports were from mcore. just wondering if this will cause a circular dependency since MBridge also depends on modelopt

from megatron.bridge.models.qwen_vl.modelling_qwen3_vl.model import Qwen3VLModel
except ImportError:
pass

__all__ = [
"export_mcore_gpt_to_hf",
"import_mcore_gpt_from_hf",
]


class _FusedLayerNormProxy(torch.nn.Module):
"""Proxy module exposing fused layernorm weights from TELayerNormColumnParallelLinear.

When using TE spec, the input layernorm and pre-MLP layernorm are fused into the
subsequent linear layer (TELayerNormColumnParallelLinear). The layernorm weight is
stored as ``layer_norm_weight`` on the fused linear module rather than as a separate
``weight`` on a standalone layernorm module.

This proxy wraps that fused weight so the existing export rules (which expect a
module with a ``.weight`` attribute) can export it with the correct HF key name.
"""

def __init__(self, fused_linear: torch.nn.Module):
super().__init__()
self.weight = fused_linear.layer_norm_weight
bias = getattr(fused_linear, "layer_norm_bias", None)
if bias is not None:
self.bias = bias
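
A minimal sketch of how this proxy is exercised, assuming the class above is importable; ``_FakeFusedLinear`` is a hypothetical stand-in for TE's ``TELayerNormColumnParallelLinear`` that models only the two attributes the proxy reads:

import torch

class _FakeFusedLinear(torch.nn.Module):
    """Hypothetical stand-in: only layer_norm_weight / layer_norm_bias are modeled."""

    def __init__(self, hidden_size):
        super().__init__()
        self.layer_norm_weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.layer_norm_bias = None  # may be absent, e.g. with RMSNorm

proxy = _FusedLayerNormProxy(_FakeFusedLinear(4096))
assert proxy.weight.shape == (4096,)  # export rules read this like a standalone layernorm weight
assert not hasattr(proxy, "bias")  # bias is only attached when the fused module carries one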


class _MoEExpertConfigProxy:
"""Proxy that presents ``moe_ffn_hidden_size`` as ``ffn_hidden_size``.

``SequentialMLP`` deep-copies the ``TransformerConfig`` and overrides
``ffn_hidden_size = moe_ffn_hidden_size`` so that each expert MLP (and the
rule function ``_gated_mlp_slicing``) sees the correct value via
``module.config.ffn_hidden_size``.

``TEGroupedMLP`` does **not** perform that override, so its
``config.ffn_hidden_size`` still holds the dense-MLP size. This proxy
bridges the gap by returning ``moe_ffn_hidden_size`` when
``ffn_hidden_size`` is accessed, and delegates everything else to the
original config.
"""

def __init__(self, config):
object.__setattr__(self, "_config", config)
object.__setattr__(
self,
"ffn_hidden_size",
getattr(config, "moe_ffn_hidden_size", config.ffn_hidden_size),
)

def __getattr__(self, name):
return getattr(self._config, name)
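
A small sketch of the override-and-delegate behavior, assuming the class above; ``SimpleNamespace`` stands in for a real ``TransformerConfig``:

from types import SimpleNamespace

# Hypothetical config carrying both the dense-MLP and the per-expert FFN sizes.
cfg = SimpleNamespace(ffn_hidden_size=14336, moe_ffn_hidden_size=2048, num_moe_experts=8)

proxy_cfg = _MoEExpertConfigProxy(cfg)
print(proxy_cfg.ffn_hidden_size)  # 2048 -- the per-expert (MoE) value is returned
print(proxy_cfg.num_moe_experts)  # 8 -- every other attribute is delegated unchanged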


class _GroupedLinearExpertProxy:
"""Present a single expert's weight slice from a TE GroupedLinear module.

TE ``GroupedLinear`` stores all expert weights as ``weight0``, ``weight1``,
…, ``weight{n-1}`` and shares a single ``weight_quantizer`` /
``input_quantizer`` across experts. The existing export rule functions
(``_name_remapping``, ``_gated_mlp_slicing``, …) expect a module with a
single ``.weight`` attribute and per-module quantizers.

This lightweight proxy satisfies that contract for a single expert by:

* Exposing ``weight{expert_id}`` as ``.weight``
* Attaching a ``config`` proxy (``_MoEExpertConfigProxy``) so that
``module.config.ffn_hidden_size`` returns ``moe_ffn_hidden_size``
* Delegating every other attribute (``weight_quantizer``,
``input_quantizer``, …) to the underlying ``GroupedLinear``.
"""

def __init__(self, grouped_linear, expert_id, config):
object.__setattr__(self, "_grouped_linear", grouped_linear)
object.__setattr__(self, "_expert_id", expert_id)
object.__setattr__(self, "config", _MoEExpertConfigProxy(config))
# Expose the individual expert weight as .weight
object.__setattr__(
self, "weight", getattr(grouped_linear, f"weight{expert_id}")
)

def __getattr__(self, name):
# Delegate quantizer attrs, bias, etc. to the GroupedLinear module
return getattr(self._grouped_linear, name)
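
A minimal sketch of presenting one expert from a grouped module, assuming the two proxy classes above; ``_FakeGroupedLinear`` is a hypothetical stand-in for TE ``GroupedLinear`` that models only the per-expert weight attributes:

import torch
from types import SimpleNamespace

class _FakeGroupedLinear(torch.nn.Module):
    """Hypothetical stand-in: expert weights live in weight0 .. weight{n-1}."""

    def __init__(self, num_experts, out_features, in_features):
        super().__init__()
        for i in range(num_experts):
            setattr(self, f"weight{i}", torch.nn.Parameter(torch.randn(out_features, in_features)))

cfg = SimpleNamespace(ffn_hidden_size=14336, moe_ffn_hidden_size=2048)
# fc1 of a gated MLP is typically 2 * moe_ffn_hidden_size (gate + up projections).
grouped = _FakeGroupedLinear(num_experts=4, out_features=2 * 2048, in_features=4096)

proxy = _GroupedLinearExpertProxy(grouped, expert_id=2, config=cfg)
assert proxy.weight is grouped.weight2  # expert 2's weight is exposed as .weight
assert proxy.config.ffn_hidden_size == 2048  # via _MoEExpertConfigProxy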


class GPTModelExporter:
"""Megatron Core GPTModel Exporter.

@@ -115,7 +200,10 @@ def __init__(
moe_router_dtype: str | None = None,
):
"""Create a GPTModel exporter instance."""
if not isinstance(model, (GPTModel, MambaModel, LLaVAModel)):
_supported_types = (GPTModel, MambaModel, LLaVAModel)
if Qwen3VLModel is not None:
_supported_types = _supported_types + (Qwen3VLModel,)
if not isinstance(model, _supported_types):
raise ValueError("Input to GPTModelExporter must be a GPTModel, MambaModel, LLaVAModel, or Qwen3VLModel!")

self._state_dict = OrderedDict()
@@ -139,13 +227,14 @@ def __init__(
self._hf_text_config.head_dim = model.config.kv_channels
self._hf_text_config.num_attention_heads = model.config.num_attention_heads
self._hf_text_config.num_key_value_heads = model.config.num_query_groups
self.is_multimodal = isinstance(model, LLaVAModel)
self.is_multimodal = isinstance(model, LLaVAModel) or (
Qwen3VLModel is not None and isinstance(model, Qwen3VLModel)
)
if not self.is_multimodal:
self._hf_text_config.intermediate_size = model.config.ffn_hidden_size
self._hf_quant_config: dict = {}
self._hf_extra_config = None
self.export_extra_modules = export_extra_modules
self.is_multimodal = isinstance(model, LLaVAModel)
self.model = model.language_model if self.is_multimodal else model
self.dtype = dtype
self.trust_remote_code = trust_remote_code
@@ -489,6 +578,17 @@ def _get_state_dict(self):
def _get_transformer_layer_state_dict(self, layer, layer_id):
if not isinstance(layer.input_layernorm, IdentityOp):
self.rules["input_layernorm"](layer.input_layernorm, layer_id)
elif (
"input_layernorm" in self.rules
and hasattr(layer, "self_attention")
and not isinstance(layer.self_attention, IdentityOp)
and hasattr(layer.self_attention, "linear_qkv")
and hasattr(layer.self_attention.linear_qkv, "layer_norm_weight")
):
# TE spec: input layernorm is fused into TELayerNormColumnParallelLinear
self.rules["input_layernorm"](
_FusedLayerNormProxy(layer.self_attention.linear_qkv), layer_id
)

if not isinstance(layer.self_attention, IdentityOp):
if "MLASelfAttention" in str(type(layer.self_attention)):
@@ -527,6 +627,14 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):

if not isinstance(layer.pre_mlp_layernorm, IdentityOp):
self.rules["pre_mlp_layernorm"](layer.pre_mlp_layernorm, layer_id)
elif (
"pre_mlp_layernorm" in self.rules
and not isinstance(layer.mlp, IdentityOp)
and hasattr(layer.mlp, "linear_fc1")
and hasattr(layer.mlp.linear_fc1, "layer_norm_weight")
):
# TE spec: pre-MLP layernorm is fused into TELayerNormColumnParallelLinear
self.rules["pre_mlp_layernorm"](_FusedLayerNormProxy(layer.mlp.linear_fc1), layer_id)

if not isinstance(layer.mlp, IdentityOp):
if "MoE" in str(type(layer.mlp)):
@@ -542,22 +650,44 @@ def _get_transformer_layer_state_dict(self, layer, layer_id):
self.rules["shared_experts.linear_fc2"](
layer.mlp.shared_experts.linear_fc2, layer_id
)
if not self.rules.get("use_packed_local_experts", False):
for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
if hasattr(layer.mlp.experts, "local_experts"):
# SequentialMLP: each expert is an individual MLP module
if not self.rules.get("use_packed_local_experts", False):
for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
self.rules["local_experts.linear_fc1"](
expert.linear_fc1, layer_id, expert_id
)
self.rules["local_experts.linear_fc2"](
expert.linear_fc2, layer_id, expert_id
)
else:
# For llama 4, in hf unified checkpoint, all local experts share one scale
self.rules["local_experts.linear_fc1"](
expert.linear_fc1, layer_id, expert_id
layer.mlp.experts.local_experts, layer_id
)
self.rules["local_experts.linear_fc2"](
expert.linear_fc2, layer_id, expert_id
layer.mlp.experts.local_experts, layer_id
)
else:
# For llama 4, in hf unified checkpoint, all local experts share one scale
self.rules["local_experts.linear_fc1"](
Contributor: does non-grouped MLP export still work even if you remove these lines?
layer.mlp.experts.local_experts, layer_id
)
self.rules["local_experts.linear_fc2"](
layer.mlp.experts.local_experts, layer_id
)
# GroupedMLP / TEGroupedMLP: experts are fused into
# linear_fc1 and linear_fc2 (TE GroupedLinear) with
# per-expert weights stored as weight0, weight1, ...
experts_module = layer.mlp.experts
num_experts = experts_module.num_local_experts
expert_config = experts_module.config
for expert_id in range(num_experts):
fc1_proxy = _GroupedLinearExpertProxy(
experts_module.linear_fc1, expert_id, expert_config
)
fc2_proxy = _GroupedLinearExpertProxy(
experts_module.linear_fc2, expert_id, expert_config
)
self.rules["local_experts.linear_fc1"](
fc1_proxy, layer_id, expert_id
)
self.rules["local_experts.linear_fc2"](
fc2_proxy, layer_id, expert_id
)
else:
self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)
@@ -598,6 +728,14 @@ def _get_mtp_state_dict(self) -> dict[str, torch.Tensor]:
def _get_mamba_layer_state_dict(self, layer, layer_id):
if not isinstance(layer.norm, IdentityOp):
self.rules["norm"](layer.norm, layer_id)
elif (
"norm" in self.rules
and hasattr(layer, "mixer")
and hasattr(layer.mixer, "in_proj")
and hasattr(layer.mixer.in_proj, "layer_norm_weight")
):
# TE spec: norm is fused into TELayerNormColumnParallelLinear (in_proj)
self.rules["norm"](_FusedLayerNormProxy(layer.mixer.in_proj), layer_id)

self.rules["mixer_norm"](layer.mixer.norm, layer_id)
self.rules["A_log"](layer.mixer.A_log, layer_id)
@@ -695,13 +833,31 @@ def _get_eagle_module_state_dict(self):
self.rules["eagle_module.shared_experts.linear_fc2"](
layer.mlp.shared_experts.linear_fc2, layer_id
)
for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
self.rules["eagle_module.local_experts.linear_fc1"](
expert.linear_fc1, layer_id, expert_id
)
self.rules["eagle_module.local_experts.linear_fc2"](
expert.linear_fc2, layer_id, expert_id
)
if hasattr(layer.mlp.experts, "local_experts"):
for expert_id, expert in enumerate(layer.mlp.experts.local_experts):
self.rules["eagle_module.local_experts.linear_fc1"](
expert.linear_fc1, layer_id, expert_id
)
self.rules["eagle_module.local_experts.linear_fc2"](
expert.linear_fc2, layer_id, expert_id
)
else:
experts_module = layer.mlp.experts
num_experts = experts_module.num_local_experts
expert_config = experts_module.config
for expert_id in range(num_experts):
fc1_proxy = _GroupedLinearExpertProxy(
experts_module.linear_fc1, expert_id, expert_config
)
fc2_proxy = _GroupedLinearExpertProxy(
experts_module.linear_fc2, expert_id, expert_config
)
self.rules["eagle_module.local_experts.linear_fc1"](
fc1_proxy, layer_id, expert_id
)
self.rules["eagle_module.local_experts.linear_fc2"](
fc2_proxy, layer_id, expert_id
)
else:
self.rules["linear_fc1"](layer.mlp.linear_fc1, layer_id)
self.rules["linear_fc2"](layer.mlp.linear_fc2, layer_id)