diff --git a/modelopt/torch/utils/dataset_utils.py b/modelopt/torch/utils/dataset_utils.py index 16bff49c2..e154de36e 100644 --- a/modelopt/torch/utils/dataset_utils.py +++ b/modelopt/torch/utils/dataset_utils.py @@ -247,7 +247,7 @@ def get_dataset_dataloader( samples = get_dataset_samples(ds_name, num_sample) all_samples.extend(samples) - batch_encoded = tokenizer.batch_encode_plus( + batch_encoded = tokenizer( all_samples, return_tensors="pt", padding=True,