diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 7ad36ab017f7..fc31e72984f4 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -2472,8 +2472,8 @@ def save_pretrained(
 
         # Add tokenizer class to the tokenizer config to be able to reload it with from_pretrained
         tokenizer_class = self.__class__.__name__
-        # Remove the Fast at the end unless we have a special `PreTrainedTokenizerFast`
-        if tokenizer_class.endswith("Fast") and tokenizer_class != "PreTrainedTokenizerFast":
+        # Remove the Fast at the end if we can save the slow tokenizer
+        if tokenizer_class.endswith("Fast") and getattr(self, "can_save_slow_tokenizer", False):
             tokenizer_class = tokenizer_class[:-4]
         tokenizer_config["tokenizer_class"] = tokenizer_class
         if getattr(self, "_auto_map", None) is not None:
diff --git a/tests/tokenization/test_tokenization_fast.py b/tests/tokenization/test_tokenization_fast.py
index d5c6444de4ec..4bd9b046d406 100644
--- a/tests/tokenization/test_tokenization_fast.py
+++ b/tests/tokenization/test_tokenization_fast.py
@@ -20,7 +20,7 @@
 import tempfile
 import unittest
 
-from transformers import AutoTokenizer, PreTrainedTokenizerFast
+from transformers import AutoTokenizer, LlamaTokenizerFast, PreTrainedTokenizerFast
 from transformers.testing_utils import require_tokenizers
 
 from ..test_tokenization_common import TokenizerTesterMixin
@@ -170,6 +170,27 @@ def test_init_from_tokenizers_model(self):
         # thus tok(sentences, truncation = True) does nothing and does not warn either
         self.assertEqual(tok(sentences, truncation = True, max_length = 8), {'input_ids': [[8774, 6, 3, 63, 31, 1748, 55, 1],[ 571, 33, 25, 3, 2, 3, 58, 1]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0],[0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1],[1, 1, 1, 1, 1, 1, 1, 1]]})  # fmt: skip
+
+    def test_class_after_save_and_reload(self):
+        # This checkpoint ships only a fast `LlamaTokenizerFast` tokenizer, with no slow-tokenizer files
+        model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B"
+
+        with tempfile.TemporaryDirectory() as temp_dir:
+            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
+            self.assertIsInstance(tokenizer, LlamaTokenizerFast)
+
+            # `use_fast=False` is ignored because no slow tokenizer can be built for this checkpoint
+            tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=False)
+            self.assertIsInstance(tokenizer, LlamaTokenizerFast)
+
+            tokenizer.save_pretrained(temp_dir)
+
+            # After a save/reload round-trip the tokenizer class must not change
+            tokenizer = AutoTokenizer.from_pretrained(temp_dir, use_fast=False)
+            self.assertIsInstance(tokenizer, LlamaTokenizerFast)
+
+            tokenizer = AutoTokenizer.from_pretrained(temp_dir, use_fast=True)
+            self.assertIsInstance(tokenizer, LlamaTokenizerFast)
 
 
 @require_tokenizers
 class TokenizerVersioningTest(unittest.TestCase):