diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index d9a6ca893..7414771a9 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -14,7 +14,9 @@ # limitations under the License. import argparse +import os import random +import sys import time import warnings from typing import Any @@ -69,6 +71,9 @@ from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader +sys.path.append(os.path.join(os.path.dirname(__file__), "../speculative_decoding")) +from eagle_utils import make_eagle_supervised_data_module + RAND_SEED = 1234 QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = { @@ -316,7 +321,12 @@ def forward_step(model, batch): def load_model(args: argparse.Namespace): # If low memory mode is enabled, we compress the model while loading the HF checkpoint. calibration_only = False - if not args.low_memory_mode: + if args.specdec_offline_dataset is not None: + full_model = AutoModelForCausalLM.from_pretrained( + args.pyt_ckpt_path, + trust_remote_code=args.trust_remote_code, + ) + elif not args.low_memory_mode: full_model = get_model( args.pyt_ckpt_path, args.device, @@ -402,15 +412,27 @@ def load_model(args: argparse.Namespace): language_model = extracted_lm model_type = extracted_model_type else: - if args.dataset is None: - args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] - warnings.warn( - "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." + if args.specdec_offline_dataset is not None: + language_model = full_model + else: + if args.dataset is None: + args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"] + warnings.warn( + "No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2." + ) + # Adjust calib_size to match dataset length by extending or truncating as needed + args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ + : len(args.dataset) + ] + + # We only quantize the language model for VLMs other than the type supported above. + extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl( + full_model ) - # Adjust calib_size to match dataset length by extending or truncating as needed - args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[ - : len(args.dataset) - ] + if extracted_lm is not None: + language_model = extracted_lm + model_type = extracted_model_type + tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code) default_padding_side = tokenizer.padding_side @@ -418,12 +440,6 @@ def load_model(args: argparse.Namespace): # Left padding usually provides better calibration result. tokenizer.padding_side = "left" - # We only quantize the language model for VLMs other than the type supported above. 
- extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model) - if extracted_lm is not None: - language_model = extracted_lm - model_type = extracted_model_type - if model_type == "phi4mm": warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.") @@ -548,6 +564,7 @@ def export_quantized( tokenizer: PreTrainedTokenizerBase | None, default_padding_side, default_pad_token, + offline_specdec_input: dict | None = None, ): with torch.inference_mode(): if model_type is None: @@ -637,6 +654,7 @@ def export_quantized( full_model, export_dir=export_path, extra_state_dict=mtp_state_dict, + offline_specdec_input=offline_specdec_input, ) # Copy custom model files (Python files and JSON configs) if trust_remote_code is used @@ -800,49 +818,73 @@ def quantize_main( device: torch.device, ): if args.batch_size == 0: - # Calibration/sparsification will actually take much more memory than regular inference - # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio - # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference. - sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 - # Whisper model expects mel-spectrogram input features of length 3000 - # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) - # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float - # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() - if model_type == "whisper": - max_sample_length = 3000 - num_mel_bins = language_model.config.num_mel_bins - sample_input_single_batch = ( - torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( - language_model.device - ) - * 100 - ) + if args.specdec_offline_dataset is not None: + # Speculative decoding offline calibration does not support get_max_batch_size() because of + # the customized dataloader, so we set batch_size to 1 to avoid OOM. + args.batch_size = 1 else: - sample_input_single_batch = None + # Calibration/sparsification will actually take much more memory than regular inference + # due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio + # to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
+ sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1 + # Whisper model expects mel-spectrogram input features of length 3000 + # Whisper model needs input of shape (batch_size, num_mel_bins, 3000) + # As the encoder of Whisper doesn't have embedding layer, input dtype has to be float + # For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size() + if model_type == "whisper": + max_sample_length = 3000 + num_mel_bins = language_model.config.num_mel_bins + sample_input_single_batch = ( + torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to( + language_model.device + ) + * 100 + ) + else: + sample_input_single_batch = None - run_auto_quant = args.auto_quantize_bits is not None + run_auto_quant = args.auto_quantize_bits is not None - args.batch_size = get_max_batch_size( - language_model, - max_sample_length=args.calib_seq, - sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, - sample_input_single_batch=sample_input_single_batch, - enable_grad=run_auto_quant, - ) - args.batch_size = min(args.batch_size, sum(args.calib_size)) + args.batch_size = get_max_batch_size( + language_model, + max_sample_length=args.calib_seq, + sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0, + sample_input_single_batch=sample_input_single_batch, + enable_grad=run_auto_quant, + ) + args.batch_size = min(args.batch_size, sum(args.calib_size)) print(f"Use calib batch_size {args.batch_size}") - calib_dataloader, first_text_speech_dataset = make_calib_dataloader( - args, language_model, processor, tokenizer, device, model_type - ) + if args.specdec_offline_dataset is not None: + data_args = argparse.Namespace( + vlm_processor=None, + vlm_img_dir=None, + data_path=args.specdec_offline_dataset, + offline_data_path=args.specdec_offline_feature, + lazy_preprocess=True, + ) + data_module = make_eagle_supervised_data_module( + tokenizer, data_args, max_length=args.calib_seq ) + calib_dataloader = DataLoader( + data_module["eval_dataset"], + batch_size=args.batch_size, + shuffle=False, + collate_fn=data_module["data_collator"], + ) + else: + calib_dataloader, first_text_speech_dataset = make_calib_dataloader( + args, language_model, processor, tokenizer, device, model_type + ) # Detect if this is a Nemotron VL model using architecture-based detection is_nemotron_vl_model = is_nemotron_vl(full_model) - preview_input_ids, generated_ids_before_ptq = pre_quantize( - args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model - ) + if args.specdec_offline_dataset is None: + preview_input_ids, generated_ids_before_ptq = pre_quantize( + args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model + ) if args.auto_quantize_bits: assert len(args.qformat.split(",")) > 1, ( @@ -915,17 +957,18 @@ def quantize_main( assert model_type != "dbrx", f"Does not support export {model_type} without quantizaton" print(f"qformat: {args.qformat}.
No quantization applied, export {device} model") - post_quantize( - args, - full_model, - model_type, - tokenizer, - processor, - preview_input_ids, - generated_ids_before_ptq, - is_nemotron_vl_model, - first_text_speech_dataset, - ) + if args.specdec_offline_dataset is None: + post_quantize( + args, + full_model, + model_type, + tokenizer, + processor, + preview_input_ids, + generated_ids_before_ptq, + is_nemotron_vl_model, + first_text_speech_dataset, + ) export_quantized( args, full_model, @@ -934,6 +977,9 @@ def quantize_main( tokenizer, default_padding_side, default_pad_token, + offline_specdec_input=next(iter(calib_dataloader), None) + if args.specdec_offline_dataset is not None + else None, ) @@ -985,6 +1031,23 @@ def parse_args() -> argparse.Namespace: type=str, default=None, ) + parser.add_argument( + "--specdec_offline_feature", + help=( + "Path to the offline base-model features (precomputed hidden states) used for " + "speculative decoding calibration; if set, the model is calibrated with an offline dataset. " + ), + default=None, + ) + parser.add_argument( + "--specdec_offline_dataset", + help=( + "Path to the offline dataset for speculative decoding model calibration. " + "This should be a JSON or JSONL file or a directory with JSON or JSONL files " + "containing the calibration samples. " + ), + default=None, + ) parser.add_argument( "--calib_with_images", action="store_true", diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 8706ca049..30bbfb77f 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -184,6 +184,10 @@ def train(): model_args.model_name_or_path, trust_remote_code=True ) model.config.num_orig_hidden_layers = model_config.num_hidden_layers + if hasattr(model.config, "layer_types"): + del ( + model.config.layer_types + ) # remove layer_types to avoid mismatch with the modified model tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py index 5703f4515..586d37863 100644 --- a/modelopt/torch/export/unified_export_hf.py +++ b/modelopt/torch/export/unified_export_hf.py @@ -269,7 +269,9 @@ def _fuse_shared_input_modules( return fused_linears -def requantize_resmooth_fused_llm_layers(model: torch.nn.Module): +def requantize_resmooth_fused_llm_layers( + model: torch.nn.Module, offline_specdec_input: dict | None = None +): """Group modules that take the same input and register shared parameters in module.""" # TODO: Handle DBRX MoE quantization_format = get_quantization_format(model) @@ -337,6 +339,9 @@ def llm_dummy_forward(): "This is required for requantization/resmoothing optimization. " "Please ensure the model architecture is supported or file an issue."
) + elif offline_specdec_input is not None: + # For offline SpecDec models, we need to pass the specific input format used during training + model(**offline_specdec_input) else: model(fake_input) @@ -698,7 +703,9 @@ def _export_transformers_checkpoint( # Resmooth and requantize fused layers # TODO: Handle mixed precision - requantize_resmooth_fused_llm_layers(model) + requantize_resmooth_fused_llm_layers( + model, offline_specdec_input=kwargs.get("offline_specdec_input") + ) # Remove all hooks from the model try: @@ -961,6 +968,7 @@ def export_hf_checkpoint( save_modelopt_state: bool = False, components: list[str] | None = None, extra_state_dict: dict[str, torch.Tensor] | None = None, + offline_specdec_input: dict | None = None, ): """Export quantized HuggingFace model checkpoint (transformers or diffusers). @@ -999,7 +1007,9 @@ def export_hf_checkpoint( return try: - post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype) + post_state_dict, hf_quant_config = _export_transformers_checkpoint( + model, dtype, offline_specdec_input=offline_specdec_input + ) if hf_quant_config is not None: # Save hf_quant_config.json for backward compatibility diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 16bff49c2..093d05eb2 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -382,8 +382,11 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None): Returns: The maximum batch size that worked successfully """ - assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), ( - "batch_data values must be tensors" + assert all( + torch.is_tensor(data) or data is None or key == "base_model_outputs" + for key, data in batch_data.items() + ), ( + "batch_data values must be tensors or None, except for 'base_model_outputs' which can be any type." ) # Get the batch size of current data batch_size = batch_data[next(iter(batch_data.keys()))].shape[0]
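
Usage sketch (reviewer note, not part of the patch): the snippet below illustrates, under stated assumptions, how the offline speculative-decoding calibration path added here is wired together. An EAGLE-style data module supplies the calibration batches, and the first batch is handed to export as offline_specdec_input so that requantize_resmooth_fused_llm_layers can replay a training-format input instead of a fake one. The tokenizer name, the file paths, the data_args field values, and the quantized_model placeholder are hypothetical; the real entry point is examples/llm_ptq/hf_ptq.py with --specdec_offline_dataset and --specdec_offline_feature.

    # Minimal sketch of the offline SpecDec calibration/export flow, assuming it is run
    # from examples/llm_ptq so that ../speculative_decoding is importable, and that
    # `quantized_model` already holds a quantized speculative decoding model in memory.
    import argparse
    import os
    import sys

    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    sys.path.append(os.path.join(os.path.dirname(__file__), "../speculative_decoding"))
    from eagle_utils import make_eagle_supervised_data_module

    from modelopt.torch.export.unified_export_hf import export_hf_checkpoint

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")  # hypothetical base model
    data_args = argparse.Namespace(
        vlm_processor=None,
        vlm_img_dir=None,
        data_path="calib_conversations.jsonl",  # hypothetical --specdec_offline_dataset value
        offline_data_path="offline_features/",  # hypothetical --specdec_offline_feature value
        lazy_preprocess=True,
    )
    data_module = make_eagle_supervised_data_module(tokenizer, data_args, max_length=512)

    # batch_size is pinned to 1 because get_max_batch_size() is skipped on this path.
    calib_dataloader = DataLoader(
        data_module["eval_dataset"],
        batch_size=1,
        shuffle=False,
        collate_fn=data_module["data_collator"],
    )

    # The first batch (which may contain non-tensor entries such as "base_model_outputs")
    # is replayed during export so fused layers see the training-time input format.
    offline_specdec_input = next(iter(calib_dataloader), None)
    export_hf_checkpoint(
        quantized_model,  # assumed: produced earlier by the quantization step in hf_ptq.py
        export_dir="exported_model",
        offline_specdec_input=offline_specdec_input,
    )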