From a5eb3139a35a1ee480dd91095d120c2e978968a1 Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:06:48 -0800 Subject: [PATCH 1/3] third-party-dataset support, for shuffle Signed-off-by: michaelfeil <63565275+michaelfeil@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 69 ++++++++++++++++++++++++++- 1 file changed, 67 insertions(+), 2 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 16bff49c2..dbb73d515 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -110,6 +110,69 @@ ] +def _third_party_get_dataset_samples( + dataset_name: str, num_samples: int, tokenizer: "PreTrainedTokenizerBase | None" +) -> list[str]: + """Load a third-party dataset with the given name and number of samples. + + for messages: apply_chat_template is applied as needed. + for text: no tokenization is done and plain text is still returned. + """ + warn( + f"Loading third-party datset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}." + ) + from datasets import load_dataset + + dataset = load_dataset( + dataset_name, + streaming=True, + split="train", + ) + dataset = dataset.shuffle(seed=42, buffer_size=10000).take(num_samples) + texts = [] + if "messages" in dataset.column_names: + if tokenizer is None: + raise ValueError( + f"Your dataset {dataset_name} has a `messages` column, but no tokenizer was provided. Are you sure you are using a tokenizer that supports chat templates?" + ) + if not hasattr(tokenizer, "apply_chat_template"): + raise ValueError( + f"Your dataset {dataset_name} has a `messages` column, but the tokenizer does not have an `apply_chat_template` method. Are you sure you are using a tokenizer that supports chat templates?" + ) + texts = [] + print( + f"Using dataset with columns of {dataset_name}: messages and tools to apply chat template." + ) + for i, sample in enumerate(dataset): + messages = sample.get("messages", []) + kwargs = {} + tools = sample.get("tools", []) + if tools: + kwargs["tools"] = tools + if not messages: + raise ValueError( + f"Column {i} in dataset {dataset_name} has no messages, or a empty messages." + ) + text: str = tokenizer.apply_chat_template(messages, **kwargs, tokenize=False) + if len(text) == 0: + raise ValueError( + f"Column {i} in dataset {dataset_name} has empty text after applying chat template." + ) + texts.append(text) + elif "prompt" in dataset.column_names: + texts = [sample["prompt"] for sample in dataset] + elif "text" in dataset.column_names: + texts = [sample["text"] for sample in dataset] + else: + raise NotImplementedError( + f"Dataset {dataset_name} is not supported. Please use one of the following: {get_supported_datasets()}. " + " For supporting thrid-party datasets, your dataset must have either a `messages` or `prompt` column, and a `train` split." + " For example the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split." + ) + + return texts + + def get_dataset_samples( dataset_name: str, num_samples: int, @@ -131,10 +194,12 @@ def get_dataset_samples( """ # Load the dataset if dataset_name not in SUPPORTED_DATASET_CONFIG: - raise NotImplementedError( + warn( f"dataset {dataset_name} is not supported. Please use one of the following:" f" {get_supported_datasets()}." + " Trying to set up via third-party datasets." ) + return _third_party_get_dataset_samples(dataset_name, num_samples, tokenizer=tokenizer) from datasets import load_dataset @@ -244,7 +309,7 @@ def get_dataset_dataloader( all_samples = [] for ds_name, num_sample in zip(dataset_name, num_samples): - samples = get_dataset_samples(ds_name, num_sample) + samples = get_dataset_samples(ds_name, num_sample, tokenizer=tokenizer) all_samples.extend(samples) batch_encoded = tokenizer.batch_encode_plus( From 00ee9335373716c4b3ca9f716d192dd5a245206f Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:13:06 -0800 Subject: [PATCH 2/3] add suggestions from coderabbit Signed-off-by: michaelfeil <63565275+michaelfeil@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index dbb73d515..a2c22f9e5 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -151,12 +151,12 @@ def _third_party_get_dataset_samples( kwargs["tools"] = tools if not messages: raise ValueError( - f"Column {i} in dataset {dataset_name} has no messages, or a empty messages." + f"Row {i} in dataset {dataset_name} has no messages, or a empty messages." ) text: str = tokenizer.apply_chat_template(messages, **kwargs, tokenize=False) if len(text) == 0: raise ValueError( - f"Column {i} in dataset {dataset_name} has empty text after applying chat template." + f"Row {i} in dataset {dataset_name} has empty text after applying chat template." ) texts.append(text) elif "prompt" in dataset.column_names: @@ -166,7 +166,7 @@ def _third_party_get_dataset_samples( else: raise NotImplementedError( f"Dataset {dataset_name} is not supported. Please use one of the following: {get_supported_datasets()}. " - " For supporting thrid-party datasets, your dataset must have either a `messages` or `prompt` column, and a `train` split." + " For supporting third-party datasets, your dataset must have either a `messages` or `prompt` column, and a `train` split." " For example the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split." ) From 560b5033a584f1376db037ebc45a0dec83017d5d Mon Sep 17 00:00:00 2001 From: michaelfeil <63565275+michaelfeil@users.noreply.github.com> Date: Tue, 3 Feb 2026 15:30:31 -0800 Subject: [PATCH 3/3] add suggestions from coderabbit Signed-off-by: michaelfeil <63565275+michaelfeil@users.noreply.github.com> --- modelopt/torch/utils/dataset_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index a2c22f9e5..9a576a770 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -119,7 +119,7 @@ def _third_party_get_dataset_samples( for text: no tokenization is done and plain text is still returned. """ warn( - f"Loading third-party datset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}." + f"Loading third-party dataset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}." ) from datasets import load_dataset