diff --git a/examples/speculative_decoding/README.md b/examples/speculative_decoding/README.md index 3294ba653..9695a21e7 100644 --- a/examples/speculative_decoding/README.md +++ b/examples/speculative_decoding/README.md @@ -242,6 +242,17 @@ To add a system prompt, use the `--system_prompt ` argument. For large scale data generation, please see [SLURM prepare data](SLURM_prepare_data.md) for SLURM support. +### Configuring Draft Model + +For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override the default settings by providing an additional JSON dict. For example, to use a 2-layer EAGLE module with an MLP intermediate size of 8192, set `eagle_config.json` to: + +```json +{ + "num_hidden_layers": 2, + "intermediate_size": 8192 +} +``` + ### Draft Vocabulary Compression We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set: @@ -252,15 +263,7 @@ python scripts/calibrate_draft_vocab.py --model meta-llama/Llama-3.2-1B-Instruct This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft token to target token. During inference, draft tokens can be mapped back to target tokens by `target_token = draft_token + d2t[draft_token]`. -### Configuring Draft Model - -For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](https://github.com/NVIDIA/Model-Optimizer/blob/main/modelopt/torch/speculative/config.py#L37) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`: - -```json -{ - "draft_vocab_size": 32000 -} -``` +Then, include the `--draft_vocab_cache ` argument when starting training with `./launch_train.sh`. The draft model will use the provided vocab table during training and export. 
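For quick reference, the end-to-end draft-vocab flow might look like the sketch below. This is a minimal, illustrative sketch rather than part of the patch above: the dataset path, vocab size, and output directory are placeholders, and the `--model`, `--data`, and `--output_dir` arguments are assumed to follow the existing `launch_train.sh` interface.

```bash
# Calibrate a 32k draft vocabulary; writes d2t.pt under --save_dir
python scripts/calibrate_draft_vocab.py \
  --model meta-llama/Llama-3.2-1B-Instruct \
  --data input_conversations/daring-anteater.jsonl \
  --draft_vocab_size 32000 \
  --save_dir draft_vocab_cache

# Train the draft model with the calibrated vocab table
./launch_train.sh \
  --model meta-llama/Llama-3.2-1B-Instruct \
  --data input_conversations/daring-anteater.jsonl \
  --output_dir eagle3_llama3.2_1b \
  --draft_vocab_cache draft_vocab_cache/d2t.pt
```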
### Interact with `modelopt.torch.speculative` diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh index 48d12aeb2..debbe6881 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens.sh @@ -19,5 +19,5 @@ python3 collect_hidden_states/compute_hidden_states_hf.py \ --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/daring-anteater.jsonl \ + --input-data synthetic_conversations/daring-anteater.jsonl \ --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh index 31e2294d9..dac0ab9a9 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_hf_compute_hiddens_dp.sh @@ -30,7 +30,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI for i in $(seq 0 $((DP_SIZE-1))) do -CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & +CUDA_VISIBLE_DEVICES=$i python3 collect_hidden_states/compute_hidden_states_hf.py --model meta-llama/Llama-3.2-1B-Instruct --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR & done wait diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh index 487d0d69d..75a27deb6 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens.sh @@ -20,6 +20,6 @@ export TLLM_LOG_LEVEL="error"; python3 collect_hidden_states/compute_hidden_states_trtllm.py \ --model meta-llama/Llama-3.2-1B-Instruct \ - --input-file synthetic_conversations/daring-anteater.jsonl \ + --input-data synthetic_conversations/daring-anteater.jsonl \ --output-dir /mnt/md0/eagle-hidden-states/llama1b/daring_anteater/ \ No newline at end of file diff --git a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh index 4b0fd1060..d06cfc061 100644 --- a/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh +++ b/examples/speculative_decoding/collect_hidden_states/run_trtllm_compute_hiddens_dp.sh @@ -33,7 +33,7 @@ split -n l/$DP_SIZE --numeric-suffixes=0 -d --additional-suffix=.jsonl $INPUT_FI for i in $(seq 0 $((DP_SIZE-1))) do -export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-file /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & +export CUDA_VISIBLE_DEVICES=$i; python3 collect_hidden_states/compute_hidden_states_trtllm.py --model $MODEL --input-data /tmp/part-0${i}.jsonl --output-dir $OUTPUT_DIR --dp-rank $i & done wait diff --git a/examples/speculative_decoding/eagle_utils.py b/examples/speculative_decoding/eagle_utils.py index 3625072b1..3ef715637 100644 --- a/examples/speculative_decoding/eagle_utils.py +++ 
b/examples/speculative_decoding/eagle_utils.py @@ -14,8 +14,6 @@ # limitations under the License. import inspect -import json -import os from collections.abc import Callable from pathlib import Path from typing import TYPE_CHECKING @@ -29,16 +27,20 @@ import transformers from datasets import load_dataset from packaging.version import Version -from PIL import Image from scripts.ar_validate import validate_ar from torch.utils.data import Dataset -from transformers import AutoProcessor, Trainer, TrainerCallback +from transformers import Trainer, TrainerCallback from transformers.trainer_pt_utils import LabelSmoother import modelopt from modelopt.torch.speculative.utils import get_ttt_msk_func from modelopt.torch.utils import print_rank_0 from modelopt.torch.utils.distributed import is_master +from modelopt.torch.utils.plugins.transformers_dataset import ( + LanguageDataCollator, + ShardedDataset, + VisionLanguageDataCollator, +) try: import wandb @@ -47,459 +49,124 @@ IGNORE_TOKEN_ID = LabelSmoother.ignore_index -REMOVE_THINK_CHAT_TEMPLATE = ( - "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" -) - - -def preprocess(examples, tokenizer, **kwargs): - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") - new_examples = { - "input_ids": [], - "attention_mask": [], - "loss_mask": [], - "labels": [], - } - for i in range(len(examples)): - messages = [] - source = examples[i]["conversations"] - - # Detect format: either role/content or from/value - def get_role_content(item): - if "role" in item and "content" in item: - return item["role"], item["content"] - elif "from" in item and "value" in item: - return item["from"], item["value"] - else: - raise ValueError(f"Unknown conversation format: {item}") - - for sentence in source: - role, content = get_role_content(sentence) - messages.append({"role": role.lower(), "content": content}) - conversation = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - - output = tokenizer( - conversation, - return_tensors="pt", - add_special_tokens=False, - truncation=True, - ) - input_ids = output.input_ids[0] - attention_mask = output.attention_mask[0] - loss_mask = torch.ones_like(input_ids) - labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)]) - new_examples["input_ids"].append(input_ids) - new_examples["attention_mask"].append(attention_mask) - new_examples["loss_mask"].append(loss_mask) - new_examples["labels"].append(labels) - - return new_examples - - -def preprocess_vlm(examples, tokenizer, processor, img_dir): - tokenizer.chat_template = tokenizer.chat_template.replace(REMOVE_THINK_CHAT_TEMPLATE, "") - new_examples = { - "input_ids": [], - "attention_mask": [], - "loss_mask": [], - "labels": [], - "pixel_values": [], - "image_flags": [], - } - for i in range(len(examples)): - messages = [] - source = examples[i]["conversations"] - - # Detect format: either role/content or from/value - def get_role_content(item): - if "role" in item and "content" in item: - return item["role"], item["content"] - elif "from" in item and "value" in item: - return item["from"], item["value"] - else: - raise ValueError(f"Unknown conversation format: {item}") - - # align role to user-assistant format - def convert_role(role): - role_map = { - "human": "user", - "gpt": "assistant", - } - return role_map[role.lower()] if role.lower() in role_map else role.lower() - - for sentence in source: - role, content = 
get_role_content(sentence) - new_role = convert_role(role) - messages.append({"role": new_role, "content": content}) - conversation = tokenizer.apply_chat_template( - messages, - tokenize=False, - add_generation_prompt=False, - ) - - img_filename = os.path.join(img_dir, examples[i]["image"]) - img = Image.open(img_filename) - output = processor(images=img, text=conversation, return_tensors="pt") - input_ids = output.input_ids[0] - attention_mask = output.attention_mask[0] - loss_mask = torch.ones_like(input_ids) - labels = torch.cat([input_ids[1:], torch.tensor([IGNORE_TOKEN_ID], dtype=input_ids.dtype)]) - # TODO: add labels and answer-only loss masking? - - new_examples["input_ids"].append(input_ids) - new_examples["attention_mask"].append(attention_mask) - new_examples["loss_mask"].append(loss_mask) - new_examples["labels"].append(labels) - new_examples["pixel_values"].append(output.pixel_values) - new_examples["image_flags"].append( - torch.ones((output.pixel_values.shape[0],), dtype=torch.int64) - ) - return new_examples +class OfflineSupervisedDataset(Dataset): + """Offline dataset for supervised fine-tuning. -class SupervisedDataset(Dataset): - """Dataset for supervised fine-tuning. + This dataset loads data on-the-fly from pre-processed .pt data files. Args: - raw_data (list): A list of raw data examples. - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. + dumped_files (list): A list of file paths to the dumped .pt files. """ def __init__( self, - raw_data, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, + dumped_files, ): super().__init__() - - print_rank_0("Formatting inputs...") - sources = raw_data - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess - self.data_dict = self.preprocess_fn( - sources, tokenizer, processor=vlm_processor, img_dir=img_dir - ) + self.dumped_files = dumped_files def __len__(self): - return len(self.data_dict["input_ids"]) + return len(self.dumped_files) def __getitem__(self, i) -> dict[str, torch.Tensor]: - return {k: self.data_dict[k][i] for k in self.data_dict} - - -class LazySupervisedDataset(Dataset): - """Lazy dataset for supervised fine-tuning. - - This dataset loads data on-the-fly when requested, which can be memory-efficient but slower. - - Args: - raw_data (list): A list of raw data examples. - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. 
- """ - - def __init__( - self, - raw_data, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, - ): - super().__init__() - print_rank_0("Formatting inputs...Skip in lazy mode") - self.tokenizer = tokenizer - self.raw_data = raw_data - self.cached_data_dict = {} - self.vlm_processor = vlm_processor - self.img_dir = img_dir - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess + offline_data = torch.load(self.dumped_files[i]) - def __len__(self): - return len(self.raw_data) - - def __getitem__(self, i) -> dict[str, torch.Tensor]: - if i in self.cached_data_dict: - return self.cached_data_dict[i] - ret = self.preprocess_fn( - [self.raw_data[i]], self.tokenizer, processor=self.vlm_processor, img_dir=self.img_dir - ) - ret = {k: ret[k][0] for k in ret} - self.cached_data_dict[i] = ret + labels = torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID) + labels[..., :-1] = offline_data["input_ids"][..., 1:] + ret = { + "input_ids": offline_data["input_ids"], + "base_model_hidden_states": offline_data["hidden_states"], + "aux_hidden_states": offline_data["aux_hidden_states"], + "attention_mask": torch.ones_like(offline_data["input_ids"]), + "loss_mask": torch.ones_like(offline_data["input_ids"]), + "labels": labels, + } return ret -class OfflineSupervisedDataset(Dataset): - """Lazy offline dataset for supervised fine-tuning. +class EagleOfflineDataCollator: + """Data collator that truncate or pads data for offline training.""" - This dataset loads data on-the-fly from pre-processed .pt data files as well as - input conversations in JSON format. + def __init__(self, train_len): + self.train_len = train_len - Args: - data_entries (list): A list of tuples (raw_data_example, file_path). - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - """ + def _pad_or_truncate(self, x: torch.Tensor, length: int, dim: int = 0): + """Pad or truncate a tensor to length along a given dimension.""" + dim = dim % x.ndim # support negative dimension - def __init__( - self, - data_entries, - tokenizer: transformers.PreTrainedTokenizer, - vlm_processor=None, - img_dir=None, - ): - super().__init__() - print_rank_0("Formatting inputs...Skip in offline mode") - self.tokenizer = tokenizer - self.data_entries = data_entries - self.vlm_processor = vlm_processor - self.img_dir = img_dir - self.preprocess_fn = preprocess_vlm if vlm_processor is not None else preprocess + # allocate output tensor + out_shape = list(x.shape) + out_shape[dim] = length + out = x.new_zeros(out_shape) - # Does not cache the hidden states, as those have an extremely large memory footprint. 
- self.cached_data_dict = {} + # consturct copy slice + slc = [slice(None)] * x.ndim + slc[dim] = slice(0, min(length, x.size(dim))) - def __len__(self): - return len(self.data_entries) + # populate output tensor + out[tuple(slc)] = x[tuple(slc)] + return out - def __getitem__(self, i) -> dict[str, torch.Tensor]: - # Load the conversational data, using the cache - raw_data, offline_file_path = self.data_entries[i] - # Extend the data sample with the hidden states from the .pt file - max_length = self.tokenizer.model_max_length - offline_data = torch.load(offline_file_path) - offline_data["input_ids"] = offline_data["input_ids"][:max_length] - offline_data["hidden_states"] = offline_data["hidden_states"][:max_length, :] - offline_data["aux_hidden_states"] = offline_data["aux_hidden_states"][:max_length, :] + def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: + base_batch = { + k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + for k in ["input_ids", "attention_mask", "loss_mask", "labels"] + } - ret = { - "input_ids": offline_data["input_ids"], - "attention_mask": torch.ones_like(offline_data["input_ids"]), - "loss_mask": torch.ones_like(offline_data["input_ids"]), - "labels": torch.full_like(offline_data["input_ids"], IGNORE_TOKEN_ID), - "kwargs": { - "base_model_outputs": { - "base_model_hidden_states": offline_data["hidden_states"], - "aux_hidden_states": offline_data["aux_hidden_states"], - } - }, + base_model_outputs = { + k: torch.stack([self._pad_or_truncate(item[k], self.train_len) for item in features]) + for k in ["base_model_hidden_states", "aux_hidden_states"] } - return ret + + batch = { + **base_batch, + "base_model_outputs": base_model_outputs, + } + return batch def make_eagle_supervised_data_module( tokenizer: transformers.PreTrainedTokenizer, data_args, - max_length=None, + train_len=None, ) -> dict: - """Make dataset and collator for supervised fine-tuning. - - Args: - tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use for data preprocessing. - data_args: Data arguments. + if data_args.offline_data_path is None: + train_dataset = ShardedDataset("json", data_files=data_args.data_path) + + if not data_args.vlm_processor: + data_collator = LanguageDataCollator( + tokenizer=tokenizer, + train_len=train_len, + return_labels=True, + ) + else: + data_collator = VisionLanguageDataCollator( + processor=data_args.vlm_processor, + train_len=train_len, + local_image_path=data_args.vlm_img_dir, + return_labels=True, + ) - Returns: - dict: A dictionary containing train and eval datasets. 
- """ - if data_args.vlm_processor: - vlm_processor = AutoProcessor.from_pretrained( - data_args.vlm_processor, trust_remote_code=True, use_fast=True - ) - vlm_img_dir = data_args.vlm_img_dir else: - vlm_processor, vlm_img_dir = None, None - # Load the conversations from the source file - print_rank_0("Loading input conversations...") - data_json = [] - data_path_p = Path(data_args.data_path) - if data_path_p.is_dir(): - # Load all .jsonl files in the directory and combine them - for jsonl_file in sorted(data_path_p.glob("*.jsonl")): - with open(jsonl_file) as f: - data_json.extend(json.loads(line) for line in f) - else: - with open(data_args.data_path) as f: - if data_args.data_path.endswith("jsonl"): - data_json = [json.loads(line) for line in f] - else: - data_json = json.load(f) - - if data_args.offline_data_path is not None: print_rank_0("Loading pre-processed data for offline training...") - dataset_cls = OfflineSupervisedDataset + assert not data_args.vlm_processor, "Offline data is not supported for VLM." - # Glob for all .pt files in the data_path directory - assert data_args.offline_data_path is not None, ( - "offline_data_path must be provided for offline training." - ) offline_data_path = Path(data_args.offline_data_path) - # Collect all pt file paths - all_files = {str(p) for p in offline_data_path.glob("*.pt")} - all_files |= {str(p) for p in offline_data_path.glob("**/*.pt")} - if not all_files: + dumped_files = [str(p) for p in offline_data_path.glob("*.pt")] + if not dumped_files: raise ValueError(f"No .pt files found in {data_args.offline_data_path}") - # Build a map from conv_id to file_path for fast lookup - print("building conv_id_to_file map...") - conv_id_to_file = {} - for pt_path in all_files: - pt_name = Path(pt_path).name - # Expect conv_id.pt - if pt_name.endswith(".pt"): - conv_id = pt_name[:-3] - conv_id_to_file[conv_id] = pt_path - - valid_entries = [] - print("filtering valid entries...") - for entry in data_json: - conv_id = entry.get("conversation_id") - if conv_id is None: - conv_id = entry.get("uuid") - if conv_id is None: - conv_id = entry.get("id") - if conv_id is None: - raise ValueError(f"Conversation ID required but not found for entry {entry}") - - file_path = conv_id_to_file.get(str(conv_id)) - if file_path is None: - continue - valid_entries.append((entry, file_path)) - - if len(valid_entries) == 0: - msg = """No valid files found in the offline data path that match the conversation IDs - in the provided data json. Please ensure that the offline data path is correct and - contains .pt files named after the conversation IDs, and that the input conversations - json has the correct format (with 'conversation_id' or 'id' fields).""" - raise ValueError(msg) - elif len(valid_entries) < len(data_json): - print_rank_0( - f"Warning: Only {len(valid_entries)} out of {len(data_json)} conversations" - " have corresponding .pt files in the offline data path. Continuing..." 
- ) - - num_train = int(len(valid_entries) * 0.95) - train_dataset = dataset_cls( - valid_entries[:num_train], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - eval_dataset = dataset_cls( - valid_entries[num_train:], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - - data_collator = DataCollatorForOffline(max_length=max_length) - else: - print_rank_0("Loading input conversations...") - dataset_cls = LazySupervisedDataset if data_args.lazy_preprocess else SupervisedDataset - - train_dataset = dataset_cls( - data_json[: int(len(data_json) * 0.95)], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - eval_dataset = dataset_cls( - data_json[int(len(data_json) * 0.95) :], - tokenizer=tokenizer, - vlm_processor=vlm_processor, - img_dir=vlm_img_dir, - ) - - data_collator = DataCollatorWithPadding(max_length=max_length) + train_dataset = OfflineSupervisedDataset(dumped_files) + data_collator = EagleOfflineDataCollator(train_len=train_len) return { "train_dataset": train_dataset, - "eval_dataset": eval_dataset, "data_collator": data_collator, } -class DataCollatorWithPadding: - def __init__(self, max_length): - self.max_length = max_length - - def paddingtensor2d(self, intensors, length): - n, dim = intensors.shape - if n > length: - return intensors[:length, :] - padding_tensor = torch.zeros(length - n, dim, dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors - - def paddingtensor(self, intensors, length): - if intensors.shape[0] > length: - return intensors[:length] - padding_tensor = torch.zeros(length - intensors.shape[0], dtype=intensors.dtype) - outtensors = torch.cat((intensors, padding_tensor)) - return outtensors - - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - batch_input_ids = torch.stack( - [self.paddingtensor(item["input_ids"], self.max_length) for item in features] - ) - batch_attention_mask = torch.stack( - [self.paddingtensor(item["attention_mask"], self.max_length) for item in features] - ) - batch_loss_mask = torch.stack( - [self.paddingtensor(item["loss_mask"], self.max_length) for item in features] - ) - - batch_labels = torch.stack( - [self.paddingtensor(item["labels"], self.max_length) for item in features] - ) - - batch = { - "input_ids": batch_input_ids, - "attention_mask": batch_attention_mask, - "loss_mask": batch_loss_mask, - "labels": batch_labels, - } - - # Collate VLM data - if "pixel_values" in features[0]: - # pixel values and image flags should be flattened inside a batch - batch["pixel_values"] = torch.cat([item["pixel_values"] for item in features], dim=0) - batch["image_flags"] = torch.cat([item["image_flags"] for item in features], dim=0) - - return batch - - -class DataCollatorForOffline(DataCollatorWithPadding): - def __call__(self, features: list[dict[str, Any]]) -> dict[str, Any]: - base_batch = super().__call__(features) - if "kwargs" not in features[0]: - raise ValueError("No kwargs found in batch features. 
Offline data required.") - - features = [item["kwargs"]["base_model_outputs"] for item in features] - - batch_hidden_states = torch.stack( - [ - self.paddingtensor2d(item["base_model_hidden_states"], self.max_length) - for item in features - ] - ) - batch_aux_hidden_states = torch.stack( - [self.paddingtensor2d(item["aux_hidden_states"], self.max_length) for item in features] - ) - - batch = { - **base_batch, - "base_model_outputs": { - "base_model_hidden_states": batch_hidden_states, - "aux_hidden_states": batch_aux_hidden_states, - }, - } - - return batch - - class EagleTrainerWithAccLog(Trainer): """Wrapper around Trainer that logs training accuracy.""" diff --git a/examples/speculative_decoding/launch_train.sh b/examples/speculative_decoding/launch_train.sh index c937d5b09..c0b9ea00e 100755 --- a/examples/speculative_decoding/launch_train.sh +++ b/examples/speculative_decoding/launch_train.sh @@ -102,6 +102,14 @@ while [ $# -gt 0 ]; do if [[ "$1" != *=* ]]; then shift; fi DP_SHARD_SIZE="${1#*=}" ;; + --log_steps*) + if [[ "$1" != *=* ]]; then shift; fi + LOG_STEPS="${1#*=}" + ;; + --draft_vocab_cache*) + if [[ "$1" != *=* ]]; then shift; fi + DRAFT_VOCAB_CACHE="${1#*=}" + ;; *) >&2 printf "Error: Invalid argument ${1#*=}\n" exit 1 @@ -138,6 +146,8 @@ AR_VALIDATE_STEPS=${AR_VALIDATE_STEPS:-1000} ESTIMATE_AR=${ESTIMATE_AR:-False} CP_SIZE=${CP_SIZE:-1} DP_SHARD_SIZE=${DP_SHARD_SIZE:-$((GPU_COUNT/CP_SIZE))} +LOG_STEPS=${LOG_STEPS:-100} +DRAFT_VOCAB_CACHE=${DRAFT_VOCAB_CACHE:-""} if [[ "$MODE" == "medusa" ]]; then SPECULATIVE_ARGS="--medusa_num_heads $MEDUSA_NUM_HEADS --medusa_num_layers $MEDUSA_NUM_LAYERS" @@ -179,6 +189,13 @@ else fi +if [[ "$DRAFT_VOCAB_CACHE" != "" ]]; then + DRAFT_VOCAB_CACHE_ARGS="--draft_vocab_cache $DRAFT_VOCAB_CACHE" +else + DRAFT_VOCAB_CACHE_ARGS="" +fi + + # Disable tokenizers parallelism to avoid warning export TOKENIZERS_PARALLELISM=False CMD="accelerate launch --mixed_precision bf16 main.py \ @@ -201,12 +218,13 @@ CMD="accelerate launch --mixed_precision bf16 main.py \ --weight_decay 0.0 \ --warmup_steps 100 \ --lr_scheduler_type linear \ - --logging_steps 100 \ + --logging_steps $LOG_STEPS \ --tf32 True \ --data_path $DATA \ --disable_tqdm $DISABLE_TQDM \ --estimate_ar $ESTIMATE_AR \ --ar_validate_steps $AR_VALIDATE_STEPS \ + $DRAFT_VOCAB_CACHE_ARGS \ $VLM_ARGS \ $OFFLINE_TRAINING_ARGS \ $SPECULATIVE_ARGS \ diff --git a/examples/speculative_decoding/main.py b/examples/speculative_decoding/main.py index 8706ca049..4cfa62f3d 100644 --- a/examples/speculative_decoding/main.py +++ b/examples/speculative_decoding/main.py @@ -76,9 +76,9 @@ class DataArguments: }, ) lazy_preprocess: bool = True - draft_vocab_cache_dir: str = field( - default="draft_vocab_cache", - metadata={"help": "Path to the d2t cache directory."}, + draft_vocab_cache: str | None = field( + default=None, + metadata={"help": "Path to d2t.pt cache file."}, ) vlm_img_dir: str = field(default=None, metadata={"help": "Path to the VLM image directory."}) vlm_processor: str = field(default=None, metadata={"help": "Path to the VLM processor."}) @@ -97,7 +97,7 @@ class TrainingArguments(transformers.TrainingArguments): ) dataloader_drop_last: bool = field(default=True) bf16: bool = field(default=True) - mode: Literal["eagle1", "eagle3", "medusa"] = "eagle3" + mode: Literal["eagle3", "medusa"] = "eagle3" estimate_ar: bool = field( default=False, metadata={"help": "Whether to estimate AR during training for logging."} ) @@ -147,30 +147,35 @@ def train(): training_args.parallelism_config.sp_backend = None 
print_rank_0(f"arguments: {model_args}, {training_args}, {medusa_args}, {eagle_args}") - # Detecting last checkpoint. - last_checkpoint = None - if os.path.isdir(training_args.output_dir): - last_checkpoint = get_last_checkpoint(training_args.output_dir) + # Detect checkpoint to resume from + last_checkpoint = ( + get_last_checkpoint(training_args.output_dir) + if os.path.isdir(training_args.output_dir) + else None + ) + if last_checkpoint: print_rank_0(f"Last checkpoint detected: {last_checkpoint}") - checkpoint = None - if training_args.resume_from_checkpoint is not None: - checkpoint = training_args.resume_from_checkpoint - elif last_checkpoint is not None: - checkpoint = last_checkpoint + checkpoint = training_args.resume_from_checkpoint or last_checkpoint use_offline_training = data_args.offline_data_path is not None + model_config = transformers.AutoConfig.from_pretrained( + model_args.model_name_or_path, trust_remote_code=True + ) + if "vl" in model_config.model_type.lower(): + model_cls = transformers.AutoModelForVision2Seq + else: + model_cls = transformers.AutoModelForCausalLM + if checkpoint: - model = transformers.AutoModelForCausalLM.from_pretrained( - checkpoint, torch_dtype="auto", trust_remote_code=True - ) + model = model_cls.from_pretrained(checkpoint, torch_dtype="auto", trust_remote_code=True) tokenizer = transformers.AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True) else: # To avoid OOM for large models, we load and convert model on CPU first. # Model will be moved to GPU during HF trainer.init(). offline_kwargs = {"num_hidden_layers": 0} if use_offline_training else {} - model = transformers.AutoModelForCausalLM.from_pretrained( + model = model_cls.from_pretrained( model_args.model_name_or_path, torch_dtype="auto", device_map="cpu", @@ -180,79 +185,41 @@ def train(): if use_offline_training: # When doing offline training, we need to set num_hidden_layers # since we override it when loading the model for space savings - model_config = transformers.AutoConfig.from_pretrained( - model_args.model_name_or_path, trust_remote_code=True - ) model.config.num_orig_hidden_layers = model_config.num_hidden_layers tokenizer = transformers.AutoTokenizer.from_pretrained( model_args.model_name_or_path, model_max_length=training_args.training_seq_len, trust_remote_code=True, ) - if tokenizer.chat_template is None: - tokenizer.chat_template = ( - "{%- for message in messages %}" - "{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}" - "{%- endfor %}" - ) - if tokenizer.pad_token_id is None: - tokenizer.pad_token_id = tokenizer.eos_token_id - if training_args.mode == "medusa": config = { "medusa_num_heads": medusa_args.medusa_num_heads, "medusa_num_layers": medusa_args.medusa_num_layers, } mtsp.convert(model, [("medusa", config)]) - elif training_args.mode in ["eagle1", "eagle3"]: - from modelopt.torch.speculative.config import ( - default_eagle_config, - eagle3_default_config, - kimik2_eagle_default_config, + elif training_args.mode == "eagle3": + custom_config = ( + json.load(open(eagle_args.eagle_config)) if eagle_args.eagle_config else {} ) - if eagle_args.eagle_decoder_type == "kimik2": - eagle_architecture_config = kimik2_eagle_default_config - else: - eagle_architecture_config = { - "eagle1": default_eagle_config, - "eagle3": eagle3_default_config, - }[training_args.mode] - - if eagle_args.eagle_config: - with open(eagle_args.eagle_config) as f: - custom_config = json.load(f) - eagle_architecture_config.update(custom_config) - 
config = { "eagle_decoder_type": eagle_args.eagle_decoder_type, "eagle_offline": use_offline_training, - "eagle_architecture_config": eagle_architecture_config, + "eagle_architecture_config": custom_config, + "draft_vocab_cache": data_args.draft_vocab_cache, } mtsp.convert(model, [("eagle", config)]) - # read draft vocab cache - if model.eagle_config.draft_vocab_size < model.eagle_config.vocab_size: - try: - model_name = os.path.basename(os.path.normpath(model_args.model_name_or_path)) - vocab_cache_path = os.path.join( - data_args.draft_vocab_cache_dir, model_name, "d2t.pt" - ) - vocab_cache = torch.load(vocab_cache_path) - model.eagle_module.d2t = vocab_cache - print_rank_0(f"Loaded draft vocab cache from {vocab_cache_path}.") - except Exception as e: - raise e else: raise Exception(f"{training_args.mode} is not supported!") print_rank_0("Loading dataset...") if training_args.mode == "medusa": data_module = make_medusa_supervised_data_module(tokenizer, data_args) - elif training_args.mode in ["eagle1", "eagle3"]: + elif training_args.mode == "eagle3": data_module = make_eagle_supervised_data_module( - tokenizer, data_args, max_length=training_args.training_seq_len + tokenizer, data_args, train_len=training_args.training_seq_len ) trainer = EagleTrainerWithAccLog( diff --git a/modelopt/torch/speculative/config.py b/modelopt/torch/speculative/config.py index 41987d4e4..4d5d3b15e 100644 --- a/modelopt/torch/speculative/config.py +++ b/modelopt/torch/speculative/config.py @@ -105,3 +105,8 @@ class EagleConfig(ModeloptBaseConfig): default="llama", description=("The class of eagle decoder to use. Available options: llama, kimik2"), ) + + draft_vocab_cache: str | None = ModeloptField( + default=None, + description=("Path to d2t.pt cache file."), + ) diff --git a/modelopt/torch/speculative/eagle/conversion.py b/modelopt/torch/speculative/eagle/conversion.py index ffaa195f2..978b25562 100644 --- a/modelopt/torch/speculative/eagle/conversion.py +++ b/modelopt/torch/speculative/eagle/conversion.py @@ -20,6 +20,7 @@ from modelopt.torch.opt.conversion import ModelLikeModule from modelopt.torch.opt.dynamic import _DMRegistryCls from modelopt.torch.opt.mode import ConvertReturnType, MetadataDict +from modelopt.torch.speculative.config import eagle3_default_config, kimik2_eagle_default_config from ..config import EagleConfig @@ -38,6 +39,14 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu EagleDMRegistry.register({original_cls: "base_model_class"})(EagleDMRegistry[cls]) break + # merge custom config with default config + default_arch_config = { + "llama": eagle3_default_config, + "kimik2": kimik2_eagle_default_config, + }[config.eagle_decoder_type] + custom_config = config.eagle_architecture_config + config.eagle_architecture_config = {**default_arch_config, **custom_config} + eagle_model = EagleDMRegistry.convert(model) eagle_model.modify( eagle_offline=config.eagle_offline, @@ -49,6 +58,7 @@ def convert_to_eagle_model(model: nn.Module, config: EagleConfig) -> ConvertRetu eagle_loss_decay_factor=config.eagle_loss_decay_factor, eagle_architecture_config=config.eagle_architecture_config, eagle_decoder_type=config.eagle_decoder_type, + draft_vocab_cache=config.draft_vocab_cache, ) # no metadata, all specified via config. 
diff --git a/modelopt/torch/speculative/eagle/eagle_model.py b/modelopt/torch/speculative/eagle/eagle_model.py index d54fdc843..c616e7b7c 100644 --- a/modelopt/torch/speculative/eagle/eagle_model.py +++ b/modelopt/torch/speculative/eagle/eagle_model.py @@ -35,6 +35,7 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + draft_vocab_cache, ): """Base Eagle Model modify function. Child class should implement the details.""" self.eagle_offline = eagle_offline @@ -45,3 +46,4 @@ def modify( self.eagle_reuse_base_decoder = eagle_reuse_base_decoder self.eagle_loss_decay_factor = eagle_loss_decay_factor self.eagle_decoder_type = eagle_decoder_type + self.draft_vocab_cache = draft_vocab_cache diff --git a/modelopt/torch/speculative/plugins/megatron_eagle.py b/modelopt/torch/speculative/plugins/megatron_eagle.py index e37e8f931..499ae6697 100644 --- a/modelopt/torch/speculative/plugins/megatron_eagle.py +++ b/modelopt/torch/speculative/plugins/megatron_eagle.py @@ -16,6 +16,7 @@ """Plugin to add EAGLE support for Megatron-Core GPT model.""" import copy +import os import warnings from contextlib import contextmanager @@ -54,6 +55,7 @@ from megatron.core.transformer.utils import sharded_state_dict_default from megatron.core.utils import make_tp_sharded_tensor_for_checkpoint from packaging.version import Version +from torch._subclasses.fake_tensor import FakeTensorMode from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel @@ -693,6 +695,7 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + draft_vocab_cache, ): if self.config.pipeline_model_parallel_size > 1: warnings.warn( @@ -715,6 +718,7 @@ def modify( eagle_loss_decay_factor=eagle_loss_decay_factor, eagle_architecture_config=eagle_architecture_config, eagle_decoder_type=eagle_decoder_type, + draft_vocab_cache=draft_vocab_cache, ) # sequence_parallel is not used in offline eagle @@ -731,11 +735,18 @@ def modify( self.eagle_config.hidden_size = self.config.hidden_size self.eagle_config.vocab_size = self.vocab_size self.eagle_config.max_sequence_length = self.max_sequence_length - self.eagle_config.draft_vocab_size = ( - self.vocab_size - if self.eagle_config.draft_vocab_size is None - else self.eagle_config.draft_vocab_size - ) + + if draft_vocab_cache is not None: + if not os.path.isfile(draft_vocab_cache): + raise FileNotFoundError( + f"Draft vocab cache provided but not found: {draft_vocab_cache}" + ) + # Read draft_vocab_size from d2t without loading tensor + with FakeTensorMode(): + d2t = torch.load(draft_vocab_cache, mmap=True) + self.eagle_config.draft_vocab_size = d2t.shape[0] + else: + self.eagle_config.draft_vocab_size = self.vocab_size if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: assert eagle_self_logit_distillation, ( diff --git a/modelopt/torch/speculative/plugins/transformers.py b/modelopt/torch/speculative/plugins/transformers.py index 3090297aa..a3683dedf 100644 --- a/modelopt/torch/speculative/plugins/transformers.py +++ b/modelopt/torch/speculative/plugins/transformers.py @@ -31,6 +31,7 @@ import contextlib import copy +import os from typing import Any import torch @@ -49,6 +50,8 @@ from transformers.utils import ModelOutput from transformers.utils.quantization_config import QuantizationMethod +from modelopt.torch.utils import print_rank_0 + from ..eagle.conversion import EagleDMRegistry from ..eagle.eagle_model import EagleModel from ..eagle.utils import expand_mask, make_causal_mask @@ -227,7 +230,7 
@@ def forward(self, x): class EagleModule(nn.Module): """Eagle module used in EAGLE model.""" - def __init__(self, config, decoder_layer_cls, bias=False): + def __init__(self, config, decoder_layer_cls, bias=False, draft_vocab_cache=None): """Init function for EagleModule.""" super().__init__() self.config = config @@ -238,17 +241,27 @@ def __init__(self, config, decoder_layer_cls, bias=False): if config.use_last_layernorm: self.norm = LlamaRMSNorm(config.hidden_size, config.rms_norm_eps) - # Optionally, we use a smaller vocab table for eagle module - if config.draft_vocab_size != config.vocab_size or config.has_lm_head: - # Need an extra lm_head for eagle module since vocab size is reduced. - assert config.draft_vocab_size <= config.vocab_size, ( - "EAGLE module's vocab size should be <= base model vocab size!" + # Load draft vocab cache if provided + if draft_vocab_cache is not None: + if not os.path.isfile(draft_vocab_cache): + raise FileNotFoundError( + f"Draft vocab cache provided but not found: {draft_vocab_cache}" + ) + d2t = torch.load(draft_vocab_cache) + if d2t.shape[0] > config.vocab_size: + raise ValueError( + f"Draft vocab cache size {d2t.shape[0]} is greater than base model vocab size {config.vocab_size}" + ) + print_rank_0( + f"Setting draft_vocab_size to {d2t.shape[0]} due to draft_vocab_cache provided." ) + config.draft_vocab_size = d2t.shape[0] + self.register_buffer("d2t", d2t) + print_rank_0(f"Loaded draft_vocab_cache from {draft_vocab_cache}.") + else: + config.draft_vocab_size = config.vocab_size - # Initialize the buffers to zero. - # Their values depend on specific tokenzier and calibrate dataset, and should be set in training script. - if config.draft_vocab_size < config.vocab_size: - self.register_buffer("d2t", torch.zeros(config.draft_vocab_size, dtype=torch.int64)) + if config.draft_vocab_size != config.vocab_size or config.has_lm_head: self.lm_head = nn.Linear( config.hidden_size, config.draft_vocab_size, @@ -342,6 +355,7 @@ def forward( past_key_values: Cache | None = None, use_cache: bool | None = None, output_attentions: bool | None = False, + position_embeddings: torch.Tensor | None = None, ): """Forward function for EagleModule.""" batch_size, seq_length, _ = hidden_states.shape @@ -372,15 +386,16 @@ def forward( else: # EAGLE-1 hidden_states = self.fc(torch.cat((inputs_embeds, hidden_states), dim=-1)) - if self.config.eagle_decoder_type == "llama": - # Lazy init rope to avoid save/load meta tensor error - if not hasattr(self, "rotary_emb"): - self.rotary_emb = LlamaRotaryEmbedding( - config=self.config, device=hidden_states.device - ) - position_embeddings = self.rotary_emb(hidden_states, position_ids) - else: - position_embeddings = None + if position_embeddings is None: + if self.config.eagle_decoder_type == "llama": + # Lazy init rope to avoid save/load meta tensor error + if not hasattr(self, "rotary_emb"): + self.rotary_emb = LlamaRotaryEmbedding( + config=self.config, device=hidden_states.device + ) + position_embeddings = self.rotary_emb(hidden_states, position_ids) + else: + position_embeddings = None for decoder_layer in self.layers: layer_outputs = decoder_layer( @@ -425,16 +440,26 @@ def _base_model_lm_head(self): @property def _base_llm_config(self): """Return the llm config for the base model, from LLM or VLM.""" - return self.config.llm_config if hasattr(self.config, "llm_config") else self.config + return ( + getattr(self.config, "text_config", None) + or getattr(self.config, "llm_config", None) + or self.config + ) def 
_find_base_model_parts(self): """Find model parts from different models and set base_{part}_path attributes.""" base_model_parts_mapping = { - "base_model_path": ["model", "backbone", "language_model.backbone"], + "base_model_path": [ + "model.language_model", + "model", + "backbone", + "language_model.backbone", + ], "base_model_embeddings_path": [ "model.embed_tokens", "backbone.embeddings", "language_model.backbone.embeddings", + "model.language_model.embed_tokens", ], "base_model_lm_head_path": ["lm_head", "language_model.lm_head"], } @@ -478,6 +503,10 @@ def _collect_aux_hidden_states_forward_hook(self, module, input, output) -> None ) self._aux_hidden_states.append(hidden_states) + def _collect_position_ids_forward_hook(self, module, args, kwargs, output) -> None: + """Collect position embeddings from base model intermediate layers, save them in attribute.""" + self._pos_embeddings = tuple(t.clone().detach() for t in kwargs["position_embeddings"]) + def pop_and_gather_aux_hiddens(self): """Pop auxiliary hidden states from base model and gather them on the draft model device.""" # In PTQ, forward method will be called with try and except to find max batch size. @@ -514,6 +543,7 @@ def modify( eagle_loss_decay_factor, eagle_architecture_config, eagle_decoder_type, + draft_vocab_cache, ): """Constructor. @@ -530,6 +560,7 @@ def modify( eagle_loss_decay_factor=eagle_loss_decay_factor, eagle_architecture_config=eagle_architecture_config, eagle_decoder_type=eagle_decoder_type, + draft_vocab_cache=draft_vocab_cache, ) if eagle_decoder_type == "llama": @@ -544,9 +575,6 @@ def modify( self.eagle_config.hidden_size = self._base_llm_config.hidden_size self.eagle_config.vocab_size = self._base_llm_config.vocab_size self.eagle_config.max_position_embeddings = self._base_llm_config.max_position_embeddings - self.eagle_config.draft_vocab_size = getattr( - self.eagle_config, "draft_vocab_size", self.eagle_config.vocab_size - ) if self.eagle_config._attn_implementation is None: self.eagle_config._attn_implementation = "sdpa" @@ -559,21 +587,13 @@ def modify( ): self.config.quantization_config.quantization_config.ignore.append("re:.*eagle_module.*") - # Use default aux_hidden_state layers if use_aux_hidden_state is True - # but no layer id is given + # Set default aux_hidden_state layers if ( self.eagle_config.use_aux_hidden_state and len(self.eagle_config.eagle_aux_hidden_state_layer_ids) == 0 ): self._set_default_aux_hidden_state_layers() - if self._base_llm_config.hidden_size != self.eagle_config.hidden_size: - raise ValueError( - "EAGLE module hidden size " - f"{self.eagle_config.hidden_size} must match base model hidden size " - f"{self._base_llm_config.hidden_size}!" 
- ) - # Freeze all parameters if self.eagle_freeze_base_model: for name, param in self.named_parameters(): @@ -582,6 +602,7 @@ def modify( self.eagle_module = EagleModule( self.eagle_config, decoder_cls, + draft_vocab_cache=draft_vocab_cache, ) # find base model, lm head, and embeddings paths @@ -595,6 +616,11 @@ def modify( if layer_idx in self.eagle_config.eagle_aux_hidden_state_layer_ids: layer.register_forward_hook(self._collect_aux_hidden_states_forward_hook) + # hardcode for qwen3vl + self.model.language_model.layers[0].register_forward_hook( + self._collect_position_ids_forward_hook, with_kwargs=True + ) + # delete base model layers for offline training if eagle_offline: self._base_model._modules.pop("layers") @@ -699,7 +725,7 @@ def _compute_ttt_attention_mask( dtypemin = torch.finfo(self._base_llm_config.dtype).min q_len = seq_length kv_len = seq_length * (1 + ttt_step) - if self.eagle_module.config._attn_implementation == "flex_attention": + if self.eagle_config._attn_implementation == "flex_attention": # Return block mask for flex attention block_mask = create_block_mask(msk_func, B=None, H=None, Q_LEN=q_len, KV_LEN=kv_len) return block_mask @@ -718,37 +744,6 @@ def _compute_ttt_attention_mask( tensor_mask = tensor_mask.repeat(batch_size, 1, 1, 1) return tensor_mask - def _llm_or_vlm_embedding(self, input_ids, kwargs): - """Return input embeddings with possibly vision embeddings for VLM.""" - tok_embeds = self._base_model_embeddings(input_ids) - - # LLM only have token embeddings - if "pixel_values" not in kwargs: - return tok_embeds - - # Otherwise, insert vision embeddings in tok_embeds - if self.config.model_type == "NemotronH_Nano_VL_V2": - vit_embeds = self.extract_feature(kwargs["pixel_values"]) - vit_embeds = vit_embeds[kwargs["image_flags"] == 1] - bs, seq_len, hid_size = tok_embeds.shape - tok_embeds = tok_embeds.reshape(bs * seq_len, hid_size) - input_ids = input_ids.reshape(bs * seq_len) - selected = input_ids == self.img_context_token_id - try: - tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds.reshape(-1, hid_size) - except Exception as e: - vit_embeds = vit_embeds.reshape(-1, hid_size) - print( - f"warning: {e}, tok_embeds[selected].shape={tok_embeds[selected].shape}, " - f"vit_embeds.shape={vit_embeds.shape}" - ) - n_token = selected.sum() - tok_embeds[selected] = tok_embeds[selected] * 0.0 + vit_embeds[:n_token] - del vit_embeds - return tok_embeds.reshape(bs, seq_len, hid_size) - else: - raise ValueError(f"VLM model type {self.config.model_type} not supported") - def _base_model_forward( self, input_ids, @@ -769,6 +764,7 @@ def _base_model_forward( **kwargs, ) past_key_values = getattr(outputs, "past_key_values", None) + base_input_embeds = outputs.hidden_states[0] base_model_hidden_states = outputs.hidden_states[-1] base_model_logits = outputs.logits @@ -780,7 +776,13 @@ def _base_model_forward( labels = labels.view(-1) base_model_loss = loss_fct(loss_logits, labels) - return base_model_hidden_states, base_model_logits, base_model_loss, past_key_values + return ( + base_input_embeds, + base_model_hidden_states, + base_model_logits, + base_model_loss, + past_key_values, + ) def _map_logits_to_draft_vocab(self, full_logits): reverse_mapping = ( @@ -804,6 +806,7 @@ def _eagle_forward( position_ids=position_ids, use_cache=True, past_key_values=eagle_cache, + position_embeddings=self._pos_embeddings, ) eagle_lm_head = ( self.eagle_module.lm_head @@ -854,7 +857,12 @@ def forward( assert past_key_values is None, "past_key_values should be None in 
training" if loss_mask is None: - loss_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) + # By default, mask out padding tokens in loss computation + loss_mask = ( + attention_mask.clone().detach() + if attention_mask is not None + else torch.ones_like(input_ids, dtype=torch.bool) + ) # ====First, we run base model forward==== if "base_model_outputs" in kwargs: @@ -867,16 +875,20 @@ def forward( base_model_logits = self.lm_head(base_model_hidden_states) base_model_loss, past_key_values = None, None else: - base_model_hidden_states, base_model_logits, base_model_loss, past_key_values = ( - self._base_model_forward( - input_ids, - attention_mask, - position_ids, - past_key_values, - self.eagle_freeze_base_model, - labels, - **kwargs, - ) + ( + base_input_embeds, + base_model_hidden_states, + base_model_logits, + base_model_loss, + past_key_values, + ) = self._base_model_forward( + input_ids, + attention_mask, + position_ids, + past_key_values, + self.eagle_freeze_base_model, + labels, + **kwargs, ) if not isinstance(past_key_values, Cache): @@ -907,7 +919,7 @@ def forward( eagle_cache, ) with torch.no_grad(): - inputs_embeds = self._llm_or_vlm_embedding(eagle_input_ids, kwargs) + inputs_embeds = base_input_embeds.roll(-1, 1) past_key_values.eagle_cache = eagle_cache @@ -1067,9 +1079,7 @@ def pseudo_speculative_generate( ) # Use SDPA attention during generation for both stability and performance - with temporary_set_config_value( - self.eagle_module.config, "_attn_implementation", "sdpa" - ): + with temporary_set_config_value(self.eagle_config, "_attn_implementation", "sdpa"): _, eagle_prenorm_h, eagle_logits, _ = self._eagle_forward( eagle_input_hidden_states, self._base_model_embeddings(eagle_ids), diff --git a/modelopt/torch/speculative/utils.py b/modelopt/torch/speculative/utils.py index d259a1fce..fc30b1f1c 100644 --- a/modelopt/torch/speculative/utils.py +++ b/modelopt/torch/speculative/utils.py @@ -42,6 +42,9 @@ def calibrate_frequent_vocab(tokenizer, text, target_vocab_size, output_file=None): """Given a calibration text, find the most common vocabs and return the mapping.""" conversations = tokenizer.apply_chat_template(text) + # Transformers5.x returns a BatchEncoding from apply_chat_template + if hasattr(conversations, "input_ids"): + conversations = conversations.input_ids counter = Counter(conversations) vocab = counter.most_common(target_vocab_size) mapping = torch.zeros(target_vocab_size, dtype=torch.int64) diff --git a/modelopt/torch/utils/plugins/transformers_dataset.py b/modelopt/torch/utils/plugins/transformers_dataset.py new file mode 100644 index 000000000..e147ebf2c --- /dev/null +++ b/modelopt/torch/utils/plugins/transformers_dataset.py @@ -0,0 +1,303 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Processing large data to tokenize for pretraining.""" + +import copy +import itertools +import os + +import torch +import transformers +from datasets import load_dataset +from transformers.trainer_pt_utils import LabelSmoother + +from modelopt.torch.utils import print_rank_0 + +REMOVE_THINK_CHAT_TEMPLATE = ( + "{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}" +) + +IGNORE_TOKEN_ID = LabelSmoother.ignore_index + + +def _sharegpt_to_openai_messages(conversations: list[dict]): + """Optionally align sharedgpt format to openai format.""" + role_mapping = { + "user": "user", + "User": "user", + "human": "user", + "assistant": "assistant", + "Assistant": "assistant", + "gpt": "assistant", + "system": "system", + "System": "system", + } + messages = [] + for msg in conversations: + role = role_mapping[msg["role"]] + content = msg["content"] + messages.append({"role": role, "content": content}) + return messages + + +class ShardedDataset(torch.utils.data.Dataset): + """Subclass of torch.utils.data.Dataset to load data from HuggingFace dataset.""" + + def __init__( + self, + name: str, + subset: str | None = None, + data_files: str | None = None, + split: str = "train", + num_shards: int = 1, + shard_index: int = 0, + num_streaming_samples: int | None = None, + ): + """Initialize the ShardedDataset.""" + self.name = name + self.subset = subset + self.split = split + self.data_files = data_files + self.num_shards = num_shards + self.shard_index = shard_index + self.num_streaming_samples = num_streaming_samples + + self._load_dataset() + + def __len__(self): + if self.num_streaming_samples is not None: + return self.num_streaming_samples + else: + return len(self._raw_samples) + + def __getitem__(self, index): + index = index // self.num_shards + + if self.num_streaming_samples is not None: + while index >= len(self._raw_samples): + self._raw_samples.append(next(self._stream_iterator)) + + return self._raw_samples[index] + + def _load_dataset(self): + dataset = load_dataset( + self.name, + self.subset, + data_files=self.data_files, + split=self.split, + # num_proc=4, # TODO: Make this configurable + streaming=self.num_streaming_samples is not None, + ) + + shard = dataset.shard(num_shards=self.num_shards, index=self.shard_index) + + if self.num_streaming_samples is not None: + self._raw_samples = [] + self._stream_samples = shard + self._stream_iterator = itertools.cycle(self._stream_samples) + else: + self._raw_samples = shard + + +class LanguageDataCollator: + """Data collator for language modeling tasks. + + Accepts samples in OpenAI or ShareGPT formats and returns + tokenized outputs with padding and truncation, including + input_ids and attention_mask. 
+ """ + + def __init__( + self, + tokenizer: transformers.PreTrainedTokenizerBase, + train_len: int = 4096, + chat_template: str | None = None, + add_generation_prompt: bool = False, + answer_only_loss: bool = False, + json_key: str = "text", + return_labels: bool = False, + ): + """Initialize the LanguageDataset.""" + if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase): + raise ValueError( + "The tokenizer must be a transformers.PreTrainedTokenizerBase but got {}".format( + type(tokenizer) + ) + ) + self.tokenizer = tokenizer + self.train_len = train_len + self.add_generation_prompt = add_generation_prompt + self.answer_only_loss = answer_only_loss + self.json_key = json_key + self.return_labels = return_labels + + if chat_template is not None: + self.tokenizer.chat_template = chat_template + else: + self._post_process_chat_template() + + self._post_process_tokenizer() + if self.tokenizer.chat_template is None: + raise ValueError("No valid chat template!") + + def _post_process_tokenizer(self): + if self.tokenizer.pad_token_id is None: + print_rank_0("The tokenizer has no pad_token_id, using eos_token_id instead.") + self.tokenizer.pad_token_id = self.tokenizer.eos_token_id + + if hasattr(self.tokenizer, "pad_token") and self.tokenizer.pad_token is None: + if self.tokenizer.eos_token == "<|eot_id|>": # nosec + self.tokenizer.pad_token = "<|end_of_text|>" # nosec + else: + raise ValueError("The tokenizer has no pad_token!") + + def _post_process_chat_template(self): + # [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the + # tokens are preserved for supervised learning. + self.tokenizer.chat_template = self.tokenizer.chat_template.replace( + REMOVE_THINK_CHAT_TEMPLATE, "" + ) + + def _process_chat_sample(self, examples: list): + tokenized_examples = self.tokenizer.apply_chat_template( + examples, + return_tensors="pt", + return_dict=True, + padding="max_length", + truncation=True, + max_length=self.train_len, + add_generation_prompt=self.add_generation_prompt, + return_assistant_tokens_mask=self.answer_only_loss, + ) + if self.return_labels: + input_ids = tokenized_examples["input_ids"] + labels = input_ids.new_full(input_ids.shape, IGNORE_TOKEN_ID) + labels[..., :-1] = input_ids[..., 1:] + tokenized_examples["labels"] = labels + return tokenized_examples + + def _process_text_sample(self, examples: list): + tokenized_examples = self.tokenizer( + examples, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.train_len, + ) + return tokenized_examples + + def __call__(self, examples): + """Call the LanguageDataCollator.""" + batch = [] + + for example in examples: + if not isinstance(example, dict): + raise ValueError("The sample must be a Dict but got {}".format(type(example))) + text = example.get(self.json_key, None) + if isinstance(text, str): + batch.append(text) + else: + messages = example.get("messages", None) + if messages is None: + conversations = example.get("conversations", None) + if conversations is None: + raise ValueError( + "The sample must in either OpenAI messages format or ShareGPT conversations format." 
+ ) + else: + messages = _sharegpt_to_openai_messages(conversations) + batch.append(messages) + + return self._process_chat_sample(batch) + + +class VisionLanguageDataCollator(LanguageDataCollator): + """VisionLanguageDataCollator is a subclass of LanguageDataCollator that is used to collate vision-language data.""" + + def __init__( + self, + processor: str, + train_len: int = 8192, + chat_template: str | None = None, + add_generation_prompt: bool = False, + answer_only_loss: bool = False, + local_image_path: str = "", + return_labels: bool = False, + ): + """Initialize the VisionLanguageDataCollator.""" + self.processor = transformers.AutoProcessor.from_pretrained(processor) + self.chat_template = chat_template + self.local_image_path = local_image_path + + super().__init__( + tokenizer=self.processor.tokenizer, + train_len=train_len, + chat_template=chat_template, + add_generation_prompt=add_generation_prompt, + answer_only_loss=answer_only_loss, + return_labels=return_labels, + ) + + def _process_multimodal_sample(self, examples): + tokenized_messages = self.processor.apply_chat_template( + examples, + tokenize=True, + return_tensors="pt", + return_dict=True, + padding="max_length", + truncation=True, + max_length=self.train_len, + add_generation_prompt=self.add_generation_prompt, + return_assistant_tokens_mask=self.answer_only_loss, + ) + + return tokenized_messages + + def __call__(self, examples): + """Call the VisionLanguageDataCollator.""" + batch = [] + + for example in examples: + messages = example.get("messages", None) + if messages is None: + conversations = example.get("conversations", None) + if conversations is None: + raise ValueError( + "The sample must be in either OpenAI messages format or ShareGPT conversations format." + ) + else: + messages = _sharegpt_to_openai_messages(conversations) + + copy_messages = copy.deepcopy(messages) + + for msg in copy_messages: + if isinstance(msg["content"], str): + msg["content"] = [{"type": "text", "text": msg["content"]}] + + for ctn in msg["content"]: + if ctn["type"] == "image" and "image" in ctn: + ctn["image"] = os.path.abspath( + os.path.join(self.local_image_path, ctn["image"]) + ) + # If any value in ctn is None, delete that key + # The HF dataloader adds Nones to align keys, which leads to errors in the processor. 
+ keys_to_delete = [k for k, v in ctn.items() if v is None] + for k in keys_to_delete: + del ctn[k] + + batch.append(copy_messages) + + return self._process_multimodal_sample(batch) diff --git a/tests/examples/speculative_decoding/conftest.py b/tests/examples/speculative_decoding/conftest.py index bc75b8783..80417f404 100644 --- a/tests/examples/speculative_decoding/conftest.py +++ b/tests/examples/speculative_decoding/conftest.py @@ -21,18 +21,20 @@ @pytest.fixture(scope="session", autouse=True) def tiny_daring_anteater_path(tmp_path_factory): - dataset_path = MODELOPT_ROOT / "examples/speculative_decoding/Daring-Anteater" + dataset_path = ( + MODELOPT_ROOT / "examples/speculative_decoding/input_conversations/daring-anteater.jsonl" + ) if not os.path.exists(dataset_path): try: run_example_command( - ["git", "clone", "https://huggingface.co/datasets/nvidia/Daring-Anteater"], + ["python", "prepare_input_conversations/add_daring_anteater.py"], "speculative_decoding", ) except Exception as e: # Ignore rate-limiting errors - pytest.skip(f"Failed to clone Daring-Anteater dataset: {e}") + pytest.skip(f"Failed to prepare dataset: {e}") output_path = tmp_path_factory.mktemp("daring_anteater") / "train.jsonl" - with open(dataset_path / "train.jsonl") as src, open(output_path, "w") as dst: + with open(dataset_path) as src, open(output_path, "w") as dst: for i, line in enumerate(src): if i >= 128: break diff --git a/tests/examples/speculative_decoding/test_eagle.py b/tests/examples/speculative_decoding/test_eagle.py index 3775b8a4c..86da7ef48 100644 --- a/tests/examples/speculative_decoding/test_eagle.py +++ b/tests/examples/speculative_decoding/test_eagle.py @@ -30,6 +30,31 @@ def eagle_output_dir(tmp_path_factory): return tmp_path_factory.mktemp("eagle_output_dir") +@pytest.fixture(scope="module") +def draft_vocab_cache_dir(tmp_path_factory): + """Eagle output directory shared in this module.""" + return tmp_path_factory.mktemp("eagle_output_dir") + + +def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path, draft_vocab_cache_dir): + """Test calibration of draft vocabulary.""" + run_example_command( + [ + "python", + "./scripts/calibrate_draft_vocab.py", + "--model", + tiny_llama_path, + "--data", + tiny_daring_anteater_path, + "--draft_vocab_size", + "100", + "--save_dir", + draft_vocab_cache_dir, + ], + "speculative_decoding", + ) + + # fmt: off @pytest.mark.parametrize("cp_size", [1, 2]) def test_llama_eagle3(tiny_llama_path, tiny_daring_anteater_path, tmp_path, eagle_output_dir, cp_size): @@ -112,17 +137,3 @@ def test_convert_to_vllm_ckpt(tiny_llama_path, eagle_output_dir): ], "speculative_decoding", ) - -@pytest.mark.skip(reason="Needs dataset conversion to role-content format; consolidate data loading first.") -def test_calibrate_draft_vocab(tiny_llama_path, tiny_daring_anteater_path,tmp_path): - """Test calibration of draft vocabulary.""" - run_example_command( - [ - "python", "./scripts/calibrate_draft_vocab.py", - "--model", tiny_llama_path, - "--data", tiny_daring_anteater_path, - "--draft_vocab_size", "100", - "--save_dir", tmp_path / "draft_vocab_cache", - ], - "speculative_decoding", - )