diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index 7dec05338..d88fbd7a1 100755
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -23,6 +23,7 @@ NVIDIA Model Optimizer Changelog (Linux)
 - Add support for image-text data calibration in PTQ for Nemotron VL models.
 - Add PTQ support for Nemotron Parse.
 - Add distillation support for LTX-2. See `examples/diffusers/distillation/README.md `_ for more details.
+- Add Megatron Core export/import mapping for Qwen3-VL (``Qwen3VLForConditionalGeneration``) vision-language models. The mapping handles the ``model.language_model.`` weight prefix used by Qwen3-VL.
 
 0.41 (2026-01-19)
 ^^^^^^^^^^^^^^^^^
diff --git a/docs/source/deployment/3_unified_hf.rst b/docs/source/deployment/3_unified_hf.rst
index 9124164b5..6664f987f 100644
--- a/docs/source/deployment/3_unified_hf.rst
+++ b/docs/source/deployment/3_unified_hf.rst
@@ -61,6 +61,7 @@ Models:
   * Llama 4, 3.x (FP8, NVFP4)
   * Qwen 3, 2.5 (FP8, NVFP4)
   * Qwen 3 MoE (FP8, NVFP4)
+  * Qwen 3-VL (FP8, NVFP4)
   * Deepseek R1/V3 (NVFP4)
   * Mixtral 8x7B (FP8, NVFP4)
   * Medusa (FP8)
diff --git a/modelopt/torch/export/plugins/mcore_common.py b/modelopt/torch/export/plugins/mcore_common.py
index d5bab9b4e..660e4eac9 100644
--- a/modelopt/torch/export/plugins/mcore_common.py
+++ b/modelopt/torch/export/plugins/mcore_common.py
@@ -39,6 +39,10 @@
     qwen25_causal_lm_export,
     qwen25_causal_lm_import,
 )
+from .mcore_qwen3vl import (
+    qwen3vl_causal_lm_export,
+    qwen3vl_causal_lm_import,
+)
 
 all_mcore_hf_export_mapping: dict[str, Any] = {
     "DeepseekV2ForCausalLM": deepseek_causal_lm_export,
@@ -54,6 +58,7 @@
     "Qwen3MoeForCausalLM": qwen3_causal_lm_export,
     "Qwen2ForCausalLM": qwen25_causal_lm_export,
     "GptOssForCausalLM": gptoss_causal_lm_export,
+    "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_export,
 }
 
 all_mcore_hf_import_mapping: dict[str, Any] = {
@@ -66,4 +71,5 @@
     "Qwen3MoeForCausalLM": qwen3_causal_lm_import,
     "Qwen2ForCausalLM": qwen25_causal_lm_import,
     "GptOssForCausalLM": gptoss_causal_lm_import,
+    "Qwen3VLForConditionalGeneration": qwen3vl_causal_lm_import,
 }
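
For reference, the registration above is keyed by the Hugging Face architecture name, so the Qwen3-VL rules are retrieved the same way as the existing Qwen3 entries. The snippet below is a minimal illustrative sketch, not part of the patch, assuming ModelOpt with this change installed; the conversion entry points that actually consume these dictionaries are outside this diff.

# Illustrative sketch only (not part of the patch): look up the Qwen3-VL rules by the
# Hugging Face architecture name, exactly as they are registered in mcore_common.py.
from modelopt.torch.export.plugins.mcore_common import (
    all_mcore_hf_export_mapping,
    all_mcore_hf_import_mapping,
)

arch = "Qwen3VLForConditionalGeneration"  # e.g. taken from hf_config.architectures[0]
import_rules = all_mcore_hf_import_mapping[arch]
export_rules = all_mcore_hf_export_mapping[arch]

# Both directions expose the same set of Megatron Core module keys.
print(sorted(import_rules) == sorted(export_rules))  # True
print(sorted(import_rules))  # ["final_layernorm", "input_layernorm", "k_layernorm", ...]
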
diff --git a/modelopt/torch/export/plugins/mcore_qwen3vl.py b/modelopt/torch/export/plugins/mcore_qwen3vl.py
new file mode 100644
index 000000000..40eb99adb
--- /dev/null
+++ b/modelopt/torch/export/plugins/mcore_qwen3vl.py
@@ -0,0 +1,120 @@
+# SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Custom mapping from Qwen3-VL Hugging Face models to Megatron Core models.
+
+Qwen3-VL model structure differs from Qwen3:
+- Language model weights are under `model.language_model.` prefix
+- Visual encoder weights are under `model.visual.` prefix
+
+This module handles the language model conversion for PTQ/QAT workflows.
+Visual components are typically kept in full precision.
+
+HuggingFace Qwen3-VL-8B structure:
+- model.language_model.embed_tokens.weight
+- model.language_model.layers.{L}.input_layernorm.weight
+- model.language_model.layers.{L}.self_attn.q_proj.weight
+- model.language_model.layers.{L}.self_attn.k_proj.weight
+- model.language_model.layers.{L}.self_attn.v_proj.weight
+- model.language_model.layers.{L}.self_attn.q_norm.weight
+- model.language_model.layers.{L}.self_attn.k_norm.weight
+- model.language_model.layers.{L}.self_attn.o_proj.weight
+- model.language_model.layers.{L}.post_attention_layernorm.weight
+- model.language_model.layers.{L}.mlp.gate_proj.weight
+- model.language_model.layers.{L}.mlp.up_proj.weight
+- model.language_model.layers.{L}.mlp.down_proj.weight
+- model.language_model.norm.weight
+- lm_head.weight
+"""
+
+from .mcore_custom import (
+    COL_ETP,
+    COL_TP,
+    REPLICATE,
+    ROW_ETP,
+    ROW_TP,
+    CustomModuleMapping,
+    GatedMLPMerging,
+    GatedMLPSlicing,
+    NameRemapping,
+    QKVMerging,
+    QKVSlicing,
+)
+
+# Import rules: HuggingFace -> Megatron Core
+qwen3vl_causal_lm_import: dict[str, CustomModuleMapping] = {
+    # Embeddings - note the language_model prefix
+    "word_embeddings": NameRemapping("model.language_model.embed_tokens.", COL_TP),
+    # Final layer norm
+    "final_layernorm": NameRemapping("model.language_model.norm.", REPLICATE),
+    # Output layer (lm_head is at root level, not under language_model)
+    "output_layer": NameRemapping("lm_head.", COL_TP),
+    # Attention - input layernorm
+    "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm.", REPLICATE),
+    # Attention - QKV projection (merged)
+    "linear_qkv": QKVMerging("model.language_model.layers.{}.self_attn.", COL_TP),
+    # Attention - output projection
+    "linear_proj": NameRemapping("model.language_model.layers.{}.self_attn.o_proj.", ROW_TP),
+    # Attention - Q/K layer norms (Qwen3 uses RMSNorm on Q and K)
+    "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm.", REPLICATE),
+    "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm.", REPLICATE),
+    # MLP - pre-MLP layernorm (post_attention_layernorm in HF)
+    "pre_mlp_layernorm": NameRemapping(
+        "model.language_model.layers.{}.post_attention_layernorm.", REPLICATE
+    ),
+    # MLP - gate_proj + up_proj merged into linear_fc1
+    "linear_fc1": GatedMLPMerging("model.language_model.layers.{}.mlp.", COL_TP),
+    # MLP - down_proj as linear_fc2
+    "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj.", ROW_TP),
+    # MoE support (for Qwen3-VL MoE variants like 30B-A3B)
+    "router": NameRemapping("model.language_model.layers.{}.mlp.gate.", REPLICATE),
+    "local_experts.linear_fc1": GatedMLPMerging(
+        "model.language_model.layers.{}.mlp.experts.{}.", COL_ETP
+    ),
+    "local_experts.linear_fc2": NameRemapping(
+        "model.language_model.layers.{}.mlp.experts.{}.down_proj.", ROW_ETP
+    ),
+}
+
+# Export rules: Megatron Core -> HuggingFace
+qwen3vl_causal_lm_export: dict[str, CustomModuleMapping] = {
+    # Embeddings
+    "word_embeddings": NameRemapping("model.language_model.embed_tokens."),
+    # Final layer norm
+    "final_layernorm": NameRemapping("model.language_model.norm."),
+    # Output layer
+    "output_layer": NameRemapping("lm_head."),
+    # Attention - input layernorm
+    "input_layernorm": NameRemapping("model.language_model.layers.{}.input_layernorm."),
+    # Attention - QKV projection (sliced back to separate q/k/v)
+    "linear_qkv": QKVSlicing("model.language_model.layers.{}.self_attn."),
+    # Attention - output projection
NameRemapping("model.language_model.layers.{}.self_attn.o_proj."), + # Attention - Q/K layer norms + "q_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.q_norm."), + "k_layernorm": NameRemapping("model.language_model.layers.{}.self_attn.k_norm."), + # MLP - pre-MLP layernorm + "pre_mlp_layernorm": NameRemapping("model.language_model.layers.{}.post_attention_layernorm."), + # MLP - linear_fc1 sliced back to gate_proj + up_proj + "linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp."), + # MLP - down_proj + "linear_fc2": NameRemapping("model.language_model.layers.{}.mlp.down_proj."), + # MoE support + "router": NameRemapping("model.language_model.layers.{}.mlp.gate."), + "local_experts.linear_fc1": GatedMLPSlicing("model.language_model.layers.{}.mlp.experts.{}."), + "local_experts.linear_fc2": NameRemapping( + "model.language_model.layers.{}.mlp.experts.{}.down_proj." + ), +} \ No newline at end of file diff --git a/tests/unit/torch/export/test_mcore_qwen3vl.py b/tests/unit/torch/export/test_mcore_qwen3vl.py new file mode 100644 index 000000000..3f57cb9c4 --- /dev/null +++ b/tests/unit/torch/export/test_mcore_qwen3vl.py @@ -0,0 +1,306 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/unit/torch/export/test_mcore_qwen3vl.py b/tests/unit/torch/export/test_mcore_qwen3vl.py
new file mode 100644
index 000000000..3f57cb9c4
--- /dev/null
+++ b/tests/unit/torch/export/test_mcore_qwen3vl.py
@@ -0,0 +1,306 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Unit tests for Qwen3-VL Megatron Core export/import plugin."""
+
+import pytest
+
+from modelopt.torch.export.plugins.mcore_custom import (
+    COL_TP,
+    REPLICATE,
+    ROW_TP,
+    GatedMLPMerging,
+    GatedMLPSlicing,
+    NameRemapping,
+    QKVMerging,
+    QKVSlicing,
+)
+from modelopt.torch.export.plugins.mcore_qwen3vl import (
+    qwen3vl_causal_lm_export,
+    qwen3vl_causal_lm_import,
+)
+
+
+# All mcore keys that a dense (non-MoE) Qwen3-VL model should have
+DENSE_MCORE_KEYS = {
+    "word_embeddings",
+    "final_layernorm",
+    "output_layer",
+    "input_layernorm",
+    "linear_qkv",
+    "linear_proj",
+    "q_layernorm",
+    "k_layernorm",
+    "pre_mlp_layernorm",
+    "linear_fc1",
+    "linear_fc2",
+}
+
+# Additional MoE keys
+MOE_MCORE_KEYS = {
+    "router",
+    "local_experts.linear_fc1",
+    "local_experts.linear_fc2",
+}
+
+
+class TestQwen3VLRegistration:
+    """Test that Qwen3-VL is registered in the global mapping."""
+
+    def test_registered_in_export_mapping(self):
+        from modelopt.torch.export.plugins.mcore_common import (
+            all_mcore_hf_export_mapping,
+        )
+
+        assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_export_mapping
+        assert (
+            all_mcore_hf_export_mapping["Qwen3VLForConditionalGeneration"]
+            is qwen3vl_causal_lm_export
+        )
+
+    def test_registered_in_import_mapping(self):
+        from modelopt.torch.export.plugins.mcore_common import (
+            all_mcore_hf_import_mapping,
+        )
+
+        assert "Qwen3VLForConditionalGeneration" in all_mcore_hf_import_mapping
+        assert (
+            all_mcore_hf_import_mapping["Qwen3VLForConditionalGeneration"]
+            is qwen3vl_causal_lm_import
+        )
+
+
+class TestQwen3VLImportMapping:
+    """Test the HuggingFace -> Megatron Core import mapping."""
+
+    def test_has_all_dense_keys(self):
+        assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys())
+
+    def test_has_all_moe_keys(self):
+        assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_import.keys())
+
+    def test_language_model_prefix(self):
+        """Qwen3-VL uses model.language_model. prefix (not model.)."""
+        prefix_keys = [
+            "word_embeddings",
+            "final_layernorm",
+            "input_layernorm",
+            "linear_qkv",
+            "linear_proj",
+            "q_layernorm",
+            "k_layernorm",
+            "pre_mlp_layernorm",
+            "linear_fc1",
+            "linear_fc2",
+        ]
+        for key in prefix_keys:
+            mapping = qwen3vl_causal_lm_import[key]
+            assert "model.language_model." in mapping.target_name_or_prefix, (
+                f"{key}: expected 'model.language_model.' prefix, "
+                f"got '{mapping.target_name_or_prefix}'"
+            )
+
+    def test_output_layer_at_root(self):
+        """lm_head is at root level, not under language_model."""
+        mapping = qwen3vl_causal_lm_import["output_layer"]
+        assert mapping.target_name_or_prefix == "lm_head."
+
+    def test_qkv_uses_merging(self):
+        assert isinstance(qwen3vl_causal_lm_import["linear_qkv"], QKVMerging)
+
+    def test_mlp_uses_gated_merging(self):
+        assert isinstance(
+            qwen3vl_causal_lm_import["linear_fc1"], GatedMLPMerging
+        )
+
+    @pytest.mark.parametrize(
+        "key",
+        [
+            "input_layernorm",
+            "q_layernorm",
+            "k_layernorm",
+            "pre_mlp_layernorm",
+            "final_layernorm",
+        ],
+    )
+    def test_layernorms_are_replicated(self, key):
+        """Layernorms should use REPLICATE (empty func_kwargs)."""
+        mapping = qwen3vl_causal_lm_import[key]
+        assert isinstance(mapping, NameRemapping)
+        assert mapping.func_kwargs == REPLICATE
+
+    @pytest.mark.parametrize(
+        "key,expected_kwargs",
+        [
+            ("word_embeddings", COL_TP),
+            ("output_layer", COL_TP),
+            ("linear_proj", ROW_TP),
+        ],
+    )
+    def test_tp_sharding(self, key, expected_kwargs):
+        mapping = qwen3vl_causal_lm_import[key]
+        assert mapping.func_kwargs == expected_kwargs
+
+
+class TestQwen3VLExportMapping:
+    """Test the Megatron Core -> HuggingFace export mapping."""
+
+    def test_has_all_dense_keys(self):
+        assert DENSE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys())
+
+    def test_has_all_moe_keys(self):
+        assert MOE_MCORE_KEYS.issubset(qwen3vl_causal_lm_export.keys())
+
+    def test_language_model_prefix(self):
+        """Export paths should also use model.language_model. prefix."""
+        prefix_keys = [
+            "word_embeddings",
+            "final_layernorm",
+            "input_layernorm",
+            "linear_qkv",
+            "linear_proj",
+            "q_layernorm",
+            "k_layernorm",
+            "pre_mlp_layernorm",
+            "linear_fc1",
+            "linear_fc2",
+        ]
+        for key in prefix_keys:
+            mapping = qwen3vl_causal_lm_export[key]
+            assert "model.language_model." in mapping.target_name_or_prefix, (
+                f"{key}: expected 'model.language_model.' prefix, "
+                f"got '{mapping.target_name_or_prefix}'"
+            )
+
+    def test_output_layer_at_root(self):
+        mapping = qwen3vl_causal_lm_export["output_layer"]
+        assert mapping.target_name_or_prefix == "lm_head."
+
+    def test_qkv_uses_slicing(self):
+        assert isinstance(qwen3vl_causal_lm_export["linear_qkv"], QKVSlicing)
+
+    def test_mlp_uses_gated_slicing(self):
+        assert isinstance(
+            qwen3vl_causal_lm_export["linear_fc1"], GatedMLPSlicing
+        )
+
+    def test_export_has_no_parallel_config(self):
+        """Export mappings should not have parallel configs."""
+        for key in ["word_embeddings", "final_layernorm", "output_layer",
+                    "input_layernorm", "linear_proj"]:
+            mapping = qwen3vl_causal_lm_export[key]
+            assert "parallel_config" not in mapping.func_kwargs
+
+
+class TestQwen3VLImportExportSymmetry:
+    """Test that import and export mappings are consistent."""
+
+    def test_same_mcore_keys(self):
+        assert set(qwen3vl_causal_lm_import.keys()) == set(
+            qwen3vl_causal_lm_export.keys()
+        )
+
+    @pytest.mark.parametrize(
+        "key",
+        [
+            "word_embeddings",
+            "final_layernorm",
+            "output_layer",
+            "input_layernorm",
+            "linear_proj",
+            "q_layernorm",
+            "k_layernorm",
+            "pre_mlp_layernorm",
+            "linear_fc2",
+            "router",
+        ],
+    )
+    def test_matching_hf_prefixes(self, key):
+        """Import and export should map to the same HF prefix."""
+        imp = qwen3vl_causal_lm_import[key]
+        exp = qwen3vl_causal_lm_export[key]
+        assert imp.target_name_or_prefix == exp.target_name_or_prefix, (
+            f"{key}: import prefix '{imp.target_name_or_prefix}' != "
+            f"export prefix '{exp.target_name_or_prefix}'"
+        )
+
+    def test_qkv_matching_prefix(self):
+        imp = qwen3vl_causal_lm_import["linear_qkv"]
+        exp = qwen3vl_causal_lm_export["linear_qkv"]
+        assert imp.target_name_or_prefix == exp.target_name_or_prefix
+
+    def test_mlp_fc1_matching_prefix(self):
+        imp = qwen3vl_causal_lm_import["linear_fc1"]
+        exp = qwen3vl_causal_lm_export["linear_fc1"]
+        assert imp.target_name_or_prefix == exp.target_name_or_prefix
+
+
+class TestQwen3VLvsQwen3Difference:
+    """Test that Qwen3-VL differs from Qwen3 only in the language_model prefix."""
+
+    def test_same_keys_as_qwen3(self):
+        from modelopt.torch.export.plugins.mcore_qwen import (
+            qwen3_causal_lm_export,
+            qwen3_causal_lm_import,
+        )
+
+        assert set(qwen3vl_causal_lm_import.keys()) == set(
+            qwen3_causal_lm_import.keys()
+        )
+        assert set(qwen3vl_causal_lm_export.keys()) == set(
+            qwen3_causal_lm_export.keys()
+        )
+
+    @pytest.mark.parametrize(
+        "key",
+        [
+            "word_embeddings",
+            "final_layernorm",
+            "input_layernorm",
+            "linear_qkv",
+            "linear_proj",
+            "q_layernorm",
+            "k_layernorm",
+            "pre_mlp_layernorm",
+            "linear_fc1",
+            "linear_fc2",
+            "router",
+            "local_experts.linear_fc1",
+            "local_experts.linear_fc2",
+        ],
+    )
+    def test_vl_adds_language_model_prefix(self, key):
+        """Qwen3-VL should have 'language_model.' inserted after 'model.'."""
+        from modelopt.torch.export.plugins.mcore_qwen import (
+            qwen3_causal_lm_import,
+        )
+
+        qwen3_prefix = qwen3_causal_lm_import[key].target_name_or_prefix
+        qwen3vl_prefix = qwen3vl_causal_lm_import[key].target_name_or_prefix
+        expected = qwen3_prefix.replace("model.", "model.language_model.", 1)
+        assert qwen3vl_prefix == expected, (
+            f"{key}: expected '{expected}', got '{qwen3vl_prefix}'"
+        )
+
+    def test_output_layer_same(self):
+        """lm_head is at root level for both Qwen3 and Qwen3-VL."""
+        from modelopt.torch.export.plugins.mcore_qwen import (
+            qwen3_causal_lm_import,
+        )
+
+        assert (
+            qwen3vl_causal_lm_import["output_layer"].target_name_or_prefix
+            == qwen3_causal_lm_import["output_layer"].target_name_or_prefix
+        )
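
Finally, a self-contained toy illustration of the prefix handling described in the changelog entry and the module docstring: language-model weights sit under `model.language_model.`, visual-encoder weights sit under `model.visual.` and are left in full precision, and `lm_head` stays at the checkpoint root. The key list below is hypothetical and abbreviated; only the string handling is meant to carry over.

# Toy illustration only; the visual-encoder key below is hypothetical.
checkpoint_keys = [
    "model.language_model.embed_tokens.weight",
    "model.language_model.layers.0.self_attn.q_proj.weight",
    "model.language_model.layers.0.mlp.down_proj.weight",
    "model.visual.patch_embed.proj.weight",  # hypothetical visual-encoder key
    "lm_head.weight",
]

LM_PREFIX = "model.language_model."

# Language-model weights (plus the root-level lm_head) are what the new mapping covers.
language_model_keys = [
    k for k in checkpoint_keys if k.startswith(LM_PREFIX) or k == "lm_head.weight"
]
# Visual-encoder weights are not remapped by this plugin.
visual_keys = [k for k in checkpoint_keys if k.startswith("model.visual.")]

# Dropping the extra "language_model." segment recovers the plain Qwen3-style name,
# which is the relationship the unit tests check against mcore_qwen.py.
for key in language_model_keys:
    print(key, "->", key.replace(LM_PREFIX, "model.", 1))
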