diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py
index 16bff49c2..9a576a770 100644
--- a/modelopt/torch/utils/dataset_utils.py
+++ b/modelopt/torch/utils/dataset_utils.py
@@ -110,6 +110,69 @@
 ]
 
 
+def _third_party_get_dataset_samples(
+    dataset_name: str, num_samples: int, tokenizer: "PreTrainedTokenizerBase | None"
+) -> list[str]:
+    """Load a third-party dataset with the given name and number of samples.
+
+    For a `messages` column, the tokenizer's chat template is applied per row.
+    For a `prompt` or `text` column, the plain text is returned without tokenization.
+    """
+    warn(
+        f"Loading third-party dataset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}."
+    )
+    from datasets import load_dataset
+
+    dataset = load_dataset(
+        dataset_name,
+        streaming=True,
+        split="train",
+    )
+    dataset = dataset.shuffle(seed=42, buffer_size=10000).take(num_samples)
+    texts = []
+    if "messages" in dataset.column_names:
+        if tokenizer is None:
+            raise ValueError(
+                f"Your dataset {dataset_name} has a `messages` column, but no tokenizer was provided. Are you sure you are using a tokenizer that supports chat templates?"
+            )
+        if not hasattr(tokenizer, "apply_chat_template"):
+            raise ValueError(
+                f"Your dataset {dataset_name} has a `messages` column, but the tokenizer does not have an `apply_chat_template` method. Are you sure you are using a tokenizer that supports chat templates?"
+            )
+        # `texts` is filled row by row so errors can reference the offending row index.
+        print(
+            f"Using dataset with columns of {dataset_name}: messages and tools to apply chat template."
+        )
+        for i, sample in enumerate(dataset):
+            messages = sample.get("messages", [])
+            kwargs = {}
+            tools = sample.get("tools", [])
+            if tools:
+                kwargs["tools"] = tools
+            if not messages:
+                raise ValueError(
+                    f"Row {i} in dataset {dataset_name} has no messages or an empty `messages` list."
+                )
+            text: str = tokenizer.apply_chat_template(messages, **kwargs, tokenize=False)
+            if len(text) == 0:
+                raise ValueError(
+                    f"Row {i} in dataset {dataset_name} has empty text after applying chat template."
+                )
+            texts.append(text)
+    elif "prompt" in dataset.column_names:
+        texts = [sample["prompt"] for sample in dataset]
+    elif "text" in dataset.column_names:
+        texts = [sample["text"] for sample in dataset]
+    else:
+        raise NotImplementedError(
+            f"Dataset {dataset_name} is not supported. Please use one of the following: {get_supported_datasets()}. "
+            " For supporting third-party datasets, your dataset must have a `messages`, `prompt`, or `text` column, and a `train` split."
+            " For example the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split."
+        )
+
+    return texts
+
+
 def get_dataset_samples(
     dataset_name: str,
     num_samples: int,
@@ -131,10 +194,12 @@ def get_dataset_samples(
     """
     # Load the dataset
     if dataset_name not in SUPPORTED_DATASET_CONFIG:
-        raise NotImplementedError(
+        warn(
             f"dataset {dataset_name} is not supported. Please use one of the following:"
             f" {get_supported_datasets()}."
+            " Trying to set up via third-party datasets."
         )
+        return _third_party_get_dataset_samples(dataset_name, num_samples, tokenizer=tokenizer)
 
     from datasets import load_dataset
 
@@ -244,7 +309,7 @@ def get_dataset_dataloader(
 
     all_samples = []
     for ds_name, num_sample in zip(dataset_name, num_samples):
-        samples = get_dataset_samples(ds_name, num_sample)
+        samples = get_dataset_samples(ds_name, num_sample, tokenizer=tokenizer)
         all_samples.extend(samples)
 
     batch_encoded = tokenizer.batch_encode_plus(