181 changes: 122 additions & 59 deletions examples/llm_ptq/hf_ptq.py
@@ -14,7 +14,9 @@
# limitations under the License.

import argparse
import os
import random
import sys
import time
import warnings
from typing import Any
@@ -69,6 +71,9 @@
from modelopt.torch.utils.speech_dataset_utils import get_speech_dataset_dataloader
from modelopt.torch.utils.vlm_dataset_utils import get_vlm_dataset_dataloader

sys.path.append(os.path.join(os.path.dirname(__file__), "../speculative_decoding"))
from eagle_utils import make_eagle_supervised_data_module

RAND_SEED = 1234

QUANT_CFG_CHOICES: dict[str, dict[str, Any]] = {
@@ -316,7 +321,12 @@ def forward_step(model, batch):
def load_model(args: argparse.Namespace):
# If low memory mode is enabled, we compress the model while loading the HF checkpoint.
calibration_only = False
if not args.low_memory_mode:
if args.specdec_offline_dataset is not None:
full_model = AutoModelForCausalLM.from_pretrained(
args.pyt_ckpt_path,
trust_remote_code=args.trust_remote_code,
)
elif not args.low_memory_mode:
full_model = get_model(
args.pyt_ckpt_path,
args.device,
@@ -402,28 +412,34 @@ def load_model(args: argparse.Namespace):
language_model = extracted_lm
model_type = extracted_model_type
else:
if args.dataset is None:
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
warnings.warn(
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
if args.specdec_offline_dataset is not None:
language_model = full_model
else:
if args.dataset is None:
args.dataset = ["cnn_dailymail", "nemotron-post-training-dataset-v2"]
warnings.warn(
"No dataset specified. Defaulting to cnn_dailymail and nemotron-post-training-dataset-v2."
)
# Adjust calib_size to match dataset length by extending or truncating as needed
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]

# We only quantize the language model for VLMs other than the type supported above.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(
full_model
)
# Adjust calib_size to match dataset length by extending or truncating as needed
args.calib_size = (args.calib_size + [args.calib_size[-1]] * len(args.dataset))[
: len(args.dataset)
]
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type

tokenizer = get_tokenizer(args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code)

default_padding_side = tokenizer.padding_side
default_pad_token = tokenizer.pad_token
# Left padding usually provides better calibration results.
tokenizer.padding_side = "left"

# We only quantize the language model for VLMs other than the type supported above.
extracted_lm, extracted_model_type = extract_and_prepare_language_model_from_vl(full_model)
if extracted_lm is not None:
language_model = extracted_lm
model_type = extracted_model_type

if model_type == "phi4mm":
warnings.warn("Please set the default input_mode to InputMode.LANGUAGE before quantizing.")

@@ -548,6 +564,7 @@ def export_quantized(
tokenizer: PreTrainedTokenizerBase | None,
default_padding_side,
default_pad_token,
offline_specdec_input: dict | None = None,
):
with torch.inference_mode():
if model_type is None:
@@ -637,6 +654,7 @@ def export_quantized(
full_model,
export_dir=export_path,
extra_state_dict=mtp_state_dict,
offline_specdec_input=offline_specdec_input,
)

# Copy custom model files (Python files and JSON configs) if trust_remote_code is used
@@ -800,49 +818,73 @@ def quantize_main(
device: torch.device,
):
if args.batch_size == 0:
# Calibration/sparsification will actually take much more memory than regular inference
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
# to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
# Whisper model expects mel-spectrogram input features of length 3000
# Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
# As the encoder of Whisper doesn't have an embedding layer, the input dtype has to be float
# For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
if model_type == "whisper":
max_sample_length = 3000
num_mel_bins = language_model.config.num_mel_bins
sample_input_single_batch = (
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
language_model.device
)
* 100
)
if args.specdec_offline_dataset is not None:
# Speculative decoding offline mode does not support get_max_batch_size() because of
# the customized dataloader, so we set batch_size to 1 to avoid OOM.
args.batch_size = 1
else:
sample_input_single_batch = None
# Calibration/sparsification will actually take much more memory than regular inference
# due to intermediate tensors for fake quantization. Setting sample_memory_usage_ratio
# to 2 to avoid OOM for AWQ/SmoothQuant fake quantization as it will take more memory than inference.
sample_memory_usage_ratio = 2 if "awq" in args.qformat or "sq" in args.qformat else 1.1
# Whisper model expects mel-spectrogram input features of length 3000
# Whisper model needs input of shape (batch_size, num_mel_bins, 3000)
# As the encoder of Whisper doesn't have an embedding layer, the input dtype has to be float
# For non-Whisper models (language models), sample_input will be set up inside get_max_batch_size()
if model_type == "whisper":
max_sample_length = 3000
num_mel_bins = language_model.config.num_mel_bins
sample_input_single_batch = (
torch.ones([1, num_mel_bins, max_sample_length], dtype=language_model.dtype).to(
language_model.device
)
* 100
)
else:
sample_input_single_batch = None

run_auto_quant = args.auto_quantize_bits is not None
run_auto_quant = args.auto_quantize_bits is not None

args.batch_size = get_max_batch_size(
language_model,
max_sample_length=args.calib_seq,
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
sample_input_single_batch=sample_input_single_batch,
enable_grad=run_auto_quant,
)
args.batch_size = min(args.batch_size, sum(args.calib_size))
args.batch_size = get_max_batch_size(
language_model,
max_sample_length=args.calib_seq,
sample_memory_usage_ratio=sample_memory_usage_ratio if not run_auto_quant else 1.0,
sample_input_single_batch=sample_input_single_batch,
enable_grad=run_auto_quant,
)
args.batch_size = min(args.batch_size, sum(args.calib_size))

print(f"Use calib batch_size {args.batch_size}")

calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
args, language_model, processor, tokenizer, device, model_type
)
if args.specdec_offline_dataset is not None:
data_args = argparse.Namespace(
vlm_processor=None,
vlm_img_dir=None,
data_path=args.specdec_offline_dataset,
offline_data_path=args.specdec_offline_feature,
lazy_preprocess=True,
)
data_module = make_eagle_supervised_data_module(
tokenizer, data_args, max_length=args.calib_seq
)
calib_dataloader = DataLoader(
data_module["eval_dataset"],
batch_size=args.batch_size,
shuffle=False,
collate_fn=data_module["data_collator"],
)
else:
calib_dataloader, first_text_speech_dataset = make_calib_dataloader(
args, language_model, processor, tokenizer, device, model_type
)

# Detect if this is a Nemotron VL model using architecture-based detection
is_nemotron_vl_model = is_nemotron_vl(full_model)

preview_input_ids, generated_ids_before_ptq = pre_quantize(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)
if args.specdec_offline_dataset is None:
preview_input_ids, generated_ids_before_ptq = pre_quantize(
args, full_model, model_type, tokenizer, calib_dataloader, is_nemotron_vl_model
)

if args.auto_quantize_bits:
assert len(args.qformat.split(",")) > 1, (
@@ -915,17 +957,18 @@ def quantize_main(
assert model_type != "dbrx", f"Does not support export {model_type} without quantization"
print(f"qformat: {args.qformat}. No quantization applied, export {device} model")

post_quantize(
args,
full_model,
model_type,
tokenizer,
processor,
preview_input_ids,
generated_ids_before_ptq,
is_nemotron_vl_model,
first_text_speech_dataset,
)
if args.specdec_offline_dataset is None:
post_quantize(
args,
full_model,
model_type,
tokenizer,
processor,
preview_input_ids,
generated_ids_before_ptq,
is_nemotron_vl_model,
first_text_speech_dataset,
)
export_quantized(
args,
full_model,
@@ -934,6 +977,9 @@
tokenizer,
default_padding_side,
default_pad_token,
offline_specdec_input=next(iter(calib_dataloader), None)
if args.specdec_offline_dataset is not None
else None,
)


@@ -985,6 +1031,23 @@ def parse_args() -> argparse.Namespace:
type=str,
default=None,
)
parser.add_argument(
"--specdec_offline_feature",
help=(
"If set, the model is a speculative decoding model,"
"which uses offline dataset for calibration. "
),
default=None,
)
parser.add_argument(
"--specdec_offline_dataset",
help=(
"Path to the offline dataset for speculative decoding model calibration. "
"This should be a JSON or JSONL file or a directory with JSON or JSONL files "
"containing the calibration samples. "
),
default=None,
)
parser.add_argument(
"--calib_with_images",
action="store_true",
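Taken together, the hf_ptq.py changes let an offline speculative-decoding dataset drive calibration through the eagle_utils data module and then reuse the first calibration batch as the exporter's dummy forward input. Below is a minimal standalone sketch of that flow for reference; the checkpoint name, file paths, FP8 config choice, max_length, and the lazy_preprocess flag are illustrative assumptions rather than part of this diff.

```python
import argparse
import sys

import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer

import modelopt.torch.quantization as mtq
from modelopt.torch.export import export_hf_checkpoint

# eagle_utils lives in examples/speculative_decoding, as in the patched script.
sys.path.append("TensorRT-Model-Optimizer/examples/speculative_decoding")
from eagle_utils import make_eagle_supervised_data_module

ckpt = "meta-llama/Llama-3.1-8B-Instruct"  # illustrative checkpoint
model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype=torch.bfloat16).cuda()
tokenizer = AutoTokenizer.from_pretrained(ckpt)
tokenizer.padding_side = "left"  # left padding usually calibrates better

# Mirror the Namespace built in quantize_main(); lazy_preprocess is an assumption here.
data_args = argparse.Namespace(
    vlm_processor=None,
    vlm_img_dir=None,
    data_path="calib_conversations.jsonl",  # --specdec_offline_dataset
    offline_data_path="offline_features/",  # --specdec_offline_feature
    lazy_preprocess=True,
)
data_module = make_eagle_supervised_data_module(tokenizer, data_args, max_length=2048)
calib_dataloader = DataLoader(
    data_module["eval_dataset"],
    batch_size=1,  # this PR pins batch_size to 1 for the offline path
    shuffle=False,
    collate_fn=data_module["data_collator"],
)


def forward_loop(m):
    # Stream the offline batches through the model so the quantizers collect statistics.
    # Only keys the base model's forward accepts are passed here.
    for batch in calib_dataloader:
        m(
            input_ids=batch["input_ids"].cuda(),
            attention_mask=batch["attention_mask"].cuda(),
        )


model = mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop)

# As in the patched export_quantized(), the first calibration batch doubles as the
# dummy forward input used for requantization/resmoothing during export.
export_hf_checkpoint(
    model,
    export_dir="exported_fp8_specdec",
    offline_specdec_input=next(iter(calib_dataloader), None),
)
```

On the command line, the same path is exercised by passing --specdec_offline_dataset (and optionally --specdec_offline_feature) to examples/llm_ptq/hf_ptq.py.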
4 changes: 4 additions & 0 deletions examples/speculative_decoding/main.py
@@ -184,6 +184,10 @@ def train():
model_args.model_name_or_path, trust_remote_code=True
)
model.config.num_orig_hidden_layers = model_config.num_hidden_layers
if hasattr(model.config, "layer_types"):
del (
model.config.layer_types
) # remove layer_types to avoid mismatch with the modified model
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
model_max_length=training_args.training_seq_len,
16 changes: 13 additions & 3 deletions modelopt/torch/export/unified_export_hf.py
@@ -269,7 +269,9 @@ def _fuse_shared_input_modules(
return fused_linears


def requantize_resmooth_fused_llm_layers(model: torch.nn.Module):
def requantize_resmooth_fused_llm_layers(
model: torch.nn.Module, offline_specdec_input: dict | None = None
):
"""Group modules that take the same input and register shared parameters in module."""
# TODO: Handle DBRX MoE
quantization_format = get_quantization_format(model)
@@ -337,6 +339,9 @@ def llm_dummy_forward():
"This is required for requantization/resmoothing optimization. "
"Please ensure the model architecture is supported or file an issue."
)
elif offline_specdec_input is not None:
# For offline SpecDec models, we need to pass the specific input format used during training
model(**offline_specdec_input)
else:
model(fake_input)

@@ -698,7 +703,9 @@ def _export_transformers_checkpoint(

# Resmooth and requantize fused layers
# TODO: Handle mixed precision
requantize_resmooth_fused_llm_layers(model)
requantize_resmooth_fused_llm_layers(
model, offline_specdec_input=kwargs.get("offline_specdec_input")
)

# Remove all hooks from the model
try:
@@ -961,6 +968,7 @@ def export_hf_checkpoint(
save_modelopt_state: bool = False,
components: list[str] | None = None,
extra_state_dict: dict[str, torch.Tensor] | None = None,
offline_specdec_input: dict | None = None,
):
"""Export quantized HuggingFace model checkpoint (transformers or diffusers).

@@ -999,7 +1007,9 @@
return

try:
post_state_dict, hf_quant_config = _export_transformers_checkpoint(model, dtype)
post_state_dict, hf_quant_config = _export_transformers_checkpoint(
model, dtype, offline_specdec_input=offline_specdec_input
)

if hf_quant_config is not None:
# Save hf_quant_config.json for backward compatibility
7 changes: 5 additions & 2 deletions modelopt/torch/utils/dataset_utils.py
@@ -382,8 +382,11 @@ def _process_batch(batch_data, infer_method, max_working_batch_size=None):
Returns:
The maximum batch size that worked successfully
"""
assert all(torch.is_tensor(data) or data is None for data in batch_data.values()), (
"batch_data values must be tensors"
assert all(
torch.is_tensor(data) or data is None or key == "base_model_outputs"
for key, data in batch_data.items()
), (
"batch_data values must be tensors or None, except for 'base_model_outputs' which can be any type."
)
# Get the batch size of current data
batch_size = batch_data[next(iter(batch_data.keys()))].shape[0]
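The relaxed assertion above accommodates offline SpecDec batches, which carry a non-tensor "base_model_outputs" entry alongside the usual tensors. A small self-contained sketch of a batch that now passes the check; the shapes and the nested structure of base_model_outputs are illustrative:

```python
import torch

batch_data = {
    "input_ids": torch.randint(0, 32000, (2, 128)),
    "attention_mask": torch.ones(2, 128, dtype=torch.long),
    "position_ids": None,
    # Exempt key: e.g. a list of per-sample cached base-model hidden states.
    "base_model_outputs": [{"hidden_states": torch.randn(128, 4096)} for _ in range(2)],
}

assert all(
    torch.is_tensor(data) or data is None or key == "base_model_outputs"
    for key, data in batch_data.items()
), "batch_data values must be tensors or None, except for 'base_model_outputs'."
```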