69 changes: 67 additions & 2 deletions modelopt/torch/utils/dataset_utils.py
@@ -110,6 +110,69 @@
]


def _third_party_get_dataset_samples(
@kevalmorabia97 (Collaborator), Feb 4, 2026

Thanks for contributing!

I think there are 2 things being done here which we could simplify further (see the sketch below):

  1. Support tools in tokenizer.apply_chat_template, which we can add to the existing get_dataset_samples() with an optional tools_key in the dataset config in SUPPORTED_DATASET_CONFIG.
  2. Custom dataset - I think we just need to import and overwrite SUPPORTED_DATASET_CONFIG from this file, then we don't need a separate function for third-party datasets.
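
For illustration, a minimal sketch of what this could look like from the caller's side. The entry schema below (`"split"`, `"messages_key"`, `"tools_key"`) is assumed for the example and is not the actual SUPPORTED_DATASET_CONFIG layout in dataset_utils.py; `tokenizer` is assumed to be an already-loaded tokenizer with a chat template.

```python
# Hypothetical sketch: register a custom dataset by mutating SUPPORTED_DATASET_CONFIG
# from user code instead of adding a dedicated third-party code path.
import modelopt.torch.utils.dataset_utils as dataset_utils

dataset_utils.SUPPORTED_DATASET_CONFIG["baseten/quant_calibration_dataset_v1"] = {
    "split": "train",            # assumed field names, for illustration only
    "messages_key": "messages",  # column holding the chat turns
    "tools_key": "tools",        # optional; forwarded to tokenizer.apply_chat_template(tools=...)
}

samples = dataset_utils.get_dataset_samples(
    "baseten/quant_calibration_dataset_v1", num_samples=512, tokenizer=tokenizer
)
```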

@michaelfeil (Contributor, Author), Feb 4, 2026

Thanks for reviewing my PR again and for the fast response.

We have 30+ datasets, some of them from customers. How can we make sure each one is covered? Do I need to register every one of them in the config? The goal of this PR is to allow datasets that are not in SUPPORTED_DATASET_CONFIG and just duck-type them.

https://huggingface.co/datasets/baseten/ptq-on-policy-nemotron-KimiK2.5
https://huggingface.co/datasets/baseten/ptq-on-policy-nemotron-GLM-4.7
https://huggingface.co/datasets/baseten/ptq-on-policy-nemotron-Kimi-K2-v2
https://huggingface.co/datasets/baseten/quant_calibration_dataset_v1

I think it would be best supported with a separate function.
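
For reference, the row shape the helper added in this PR duck-types against looks roughly like the following. The field contents are invented; only the column names (`messages`, optional `tools`, or alternatively `prompt`/`text`) and the `train` split matter.

```python
# Illustrative row from such a calibration dataset (contents are invented).
row = {
    "messages": [
        {"role": "user", "content": "Summarize the attached release notes."},
        {"role": "assistant", "content": "The release adds FP8 calibration support ..."},
    ],
    # Optional; forwarded to tokenizer.apply_chat_template(tools=...) when present.
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Look up the current weather for a city.",
                "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
            },
        }
    ],
}
```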

@michaelfeil (Contributor, Author)

Motivation behind this (see the Baseten-NVIDIA channel): https://basetenlabs.slack.com/archives/C04BUDD86FR/p1770160944160069?thread_ts=1770160618.661969&cid=C04BUDD86FR.

We essentially never want a `cnn_dailymail`-not-available scenario to impact customers on Baseten again. SUPPORTED_DATASET_CONFIG has brought us a ton of pain, and monkey-patching a module-level variable at import time, combined with the TRT MPI magic, is not working well.

@michaelfeil (Contributor, Author)

After the incident with modelopt, I am super biased, but I think SUPPORTED_DATASET_CONFIG is just not a good concept, and the concatenation of the user strings via `\n` was a quality issue in the PTQ version of the model (it obviously needs to follow the exact chat-template tokens for outlier-free amax).
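
To make the concern concrete, a short sketch of the difference (a simplified illustration; `messages` stands for one dataset row and `tokenizer` for the model's tokenizer, and the exact legacy concatenation in dataset_utils.py may differ):

```python
# Naive calibration text: user strings joined with "\n". The model never sees this
# exact token stream at inference time, so activation statistics (amax) can be skewed.
naive_text = "\n".join(m["content"] for m in messages if m["role"] == "user")

# Chat-template rendering reproduces the exact special tokens (role headers, BOS/EOS,
# tool blocks) the deployed model will actually see, which is what calibration should use.
calib_text = tokenizer.apply_chat_template(messages, tokenize=False)
```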

@michaelfeil (Contributor, Author)

@kevalmorabia97 Is there an update here?

    dataset_name: str, num_samples: int, tokenizer: "PreTrainedTokenizerBase | None"
) -> list[str]:
    """Load a third-party dataset with the given name and number of samples.

    If the dataset has a `messages` column, `tokenizer.apply_chat_template` is applied
    (forwarding `tools` when present). For `prompt` or `text` columns, the plain text is
    returned without tokenization.
    """
    warn(
        f"Loading third-party dataset {dataset_name} with the split `train`, as the dataset is not registered in {get_supported_datasets()}."
    )
    from datasets import load_dataset

    # Stream the dataset and take a deterministic, shuffled subset of num_samples rows.
    dataset = load_dataset(
        dataset_name,
        streaming=True,
        split="train",
    )
    dataset = dataset.shuffle(seed=42, buffer_size=10000).take(num_samples)
    texts = []
    if "messages" in dataset.column_names:
        if tokenizer is None:
            raise ValueError(
                f"Your dataset {dataset_name} has a `messages` column, but no tokenizer was provided. Pass a tokenizer that supports chat templates."
            )
        if not hasattr(tokenizer, "apply_chat_template"):
            raise ValueError(
                f"Your dataset {dataset_name} has a `messages` column, but the tokenizer does not have an `apply_chat_template` method. Are you sure you are using a tokenizer that supports chat templates?"
            )
        print(
            f"Dataset {dataset_name} has a `messages` column; applying the chat template (forwarding `tools` when present)."
        )
        for i, sample in enumerate(dataset):
            messages = sample.get("messages", [])
            if not messages:
                raise ValueError(
                    f"Row {i} in dataset {dataset_name} has a missing or empty `messages` list."
                )
            kwargs = {}
            tools = sample.get("tools", [])
            if tools:
                # Forward tool definitions so the chat template renders the exact tool tokens.
                kwargs["tools"] = tools
            text: str = tokenizer.apply_chat_template(messages, **kwargs, tokenize=False)
            if len(text) == 0:
                raise ValueError(
                    f"Row {i} in dataset {dataset_name} has empty text after applying the chat template."
                )
            texts.append(text)
    elif "prompt" in dataset.column_names:
        texts = [sample["prompt"] for sample in dataset]
    elif "text" in dataset.column_names:
        texts = [sample["text"] for sample in dataset]
    else:
        raise NotImplementedError(
            f"Dataset {dataset_name} is not supported. Please use one of the following: {get_supported_datasets()}."
            " To be supported as a third-party dataset, your dataset must have a `messages`, `prompt`, or `text` column, and a `train` split."
            " For example, the `baseten/quant_calibration_dataset_v1` dataset has a `messages` column and a `train` split."
        )

    return texts
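
# --- Illustrative usage sketch (not part of the patch; names below are assumptions) ---
# With the fallback added to get_dataset_samples() further down, an unregistered dataset
# now warns and is loaded through _third_party_get_dataset_samples():
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")  # any chat-template tokenizer
#   samples = get_dataset_samples(
#       "baseten/quant_calibration_dataset_v1", num_samples=256, tokenizer=tokenizer
#   )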


def get_dataset_samples(
dataset_name: str,
num_samples: int,
@@ -131,10 +194,12 @@ def get_dataset_samples(
"""
    # Load the dataset
    if dataset_name not in SUPPORTED_DATASET_CONFIG:
-        raise NotImplementedError(
+        warn(
            f"dataset {dataset_name} is not supported. Please use one of the following:"
            f" {get_supported_datasets()}."
+            " Trying to load it as a third-party dataset instead."
        )
+        return _third_party_get_dataset_samples(dataset_name, num_samples, tokenizer=tokenizer)

from datasets import load_dataset

@@ -244,7 +309,7 @@ def get_dataset_dataloader(

    all_samples = []
    for ds_name, num_sample in zip(dataset_name, num_samples):
-        samples = get_dataset_samples(ds_name, num_sample)
+        samples = get_dataset_samples(ds_name, num_sample, tokenizer=tokenizer)
        all_samples.extend(samples)

batch_encoded = tokenizer.batch_encode_plus(