From c4d78f01de30152f0e7a6f39d0901b85b0fd422d Mon Sep 17 00:00:00 2001
From: SaulLu <55560583+SaulLu@users.noreply.github.com>
Date: Wed, 1 Sep 2021 16:32:56 +0200
Subject: [PATCH] Fix tokenizer saving during training with `Trainer` (#12806)

* add test in trainer and test tokenizer saving with trainer

* quality

* reverse trainer changes

* replace test in test_trainer by a test for all the tokenizers

* format

* add can_save_slow_tokenizer attribute to all tokenizers

* fix Herbert

* format

* Change comment in error

* add comments and a new assert

* Update src/transformers/models/albert/tokenization_albert_fast.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change ValueError barthez

* change ValueError BigBird

* change ValueError Camembert

* change ValueError Mbart50

* change ValueError Pegasus

* change ValueError ReFormer

* change ValueError T5

* change ValueError RoBERTa

* XLNET fast

* Update tests/test_tokenization_common.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change `assert` into `self.assertIn`

* format

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
---
 .../models/albert/tokenization_albert_fast.py |  7 +++++
 .../barthez/tokenization_barthez_fast.py      |  7 +++++
 .../big_bird/tokenization_big_bird_fast.py    |  7 +++++
 .../camembert/tokenization_camembert_fast.py  |  7 +++++
 .../herbert/tokenization_herbert_fast.py      |  5 +---
 .../mbart50/tokenization_mbart50_fast.py      |  7 +++++
 .../pegasus/tokenization_pegasus_fast.py      |  7 +++++
 .../reformer/tokenization_reformer_fast.py    |  7 +++++
 .../models/t5/tokenization_t5_fast.py         |  7 +++++
 .../tokenization_xlm_roberta_fast.py          |  7 +++++
 .../models/xlnet/tokenization_xlnet_fast.py   |  7 +++++
 src/transformers/tokenization_utils_fast.py   |  7 ++++-
 tests/test_tokenization_common.py             | 27 +++++++++++++++++++
 13 files changed, 104 insertions(+), 5 deletions(-)
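For context: before this patch, a fast tokenizer saved with `legacy_format=False` (so only `tokenizer.json` on disk, no sentencepiece vocab file) could not be saved again, because `save_pretrained` with the default `legacy_format=None` still tried the slow (legacy) path, whose `save_vocabulary` needs a vocab file that no longer exists. `Trainer.save_model` hit exactly this path when checkpointing. A rough reproduction sketch; the model name and directory names here are illustrative, not taken from the patch:

    from transformers import AutoTokenizer

    # Save only the fast files: the directory ends up with tokenizer.json
    # but no sentencepiece vocab file.
    tok = AutoTokenizer.from_pretrained("t5-small", use_fast=True)
    tok.save_pretrained("tok-fast-only", legacy_format=False)

    # Reload from the fast-only directory; self.vocab_file is now None.
    reloaded = AutoTokenizer.from_pretrained("tok-fast-only")

    # Before this patch: this attempted the legacy (slow) save path and failed.
    # After this patch: the slow files are skipped (can_save_slow_tokenizer is
    # False) and only the fast files are written.
    reloaded.save_pretrained("tok-roundtrip")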
diff --git a/src/transformers/models/albert/tokenization_albert_fast.py b/src/transformers/models/albert/tokenization_albert_fast.py
index 44e4a3f7355..60c9d3144f1 100644
--- a/src/transformers/models/albert/tokenization_albert_fast.py
+++ b/src/transformers/models/albert/tokenization_albert_fast.py
@@ -158,6 +158,7 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
         self.remove_space = remove_space
         self.keep_accents = keep_accents
         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -216,6 +217,12 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/barthez/tokenization_barthez_fast.py b/src/transformers/models/barthez/tokenization_barthez_fast.py
index 41e1bae911b..a66f5936a9f 100644
--- a/src/transformers/models/barthez/tokenization_barthez_fast.py
+++ b/src/transformers/models/barthez/tokenization_barthez_fast.py
@@ -137,6 +137,7 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -187,6 +188,12 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/big_bird/tokenization_big_bird_fast.py b/src/transformers/models/big_bird/tokenization_big_bird_fast.py
index d1c33aaf26a..36f2afa3373 100644
--- a/src/transformers/models/big_bird/tokenization_big_bird_fast.py
+++ b/src/transformers/models/big_bird/tokenization_big_bird_fast.py
@@ -138,6 +138,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -227,6 +228,12 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
         return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/camembert/tokenization_camembert_fast.py b/src/transformers/models/camembert/tokenization_camembert_fast.py
index c2da521d8b8..cce7e2f63cb 100644
--- a/src/transformers/models/camembert/tokenization_camembert_fast.py
+++ b/src/transformers/models/camembert/tokenization_camembert_fast.py
@@ -135,6 +135,7 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -186,6 +187,12 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
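The same two-part guard recurs in each sentencepiece-backed tokenizer below (HerBERT, which is BPE-based, instead gets `tokenizer.json` registered in its `VOCAB_FILES_NAMES`). Condensed into a single sketch; the class name is a stand-in, not one from the patch:

    from typing import Optional, Tuple

    from transformers import PreTrainedTokenizerFast

    class MySentencePieceTokenizerFast(PreTrainedTokenizerFast):
        def __init__(self, vocab_file=None, tokenizer_file=None, **kwargs):
            super().__init__(vocab_file, tokenizer_file=tokenizer_file, **kwargs)
            self.vocab_file = vocab_file
            # True only when the slow (sentencepiece) vocab file is available.
            self.can_save_slow_tokenizer = False if not self.vocab_file else True

        def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
            if not self.can_save_slow_tokenizer:
                raise ValueError(
                    "Your fast tokenizer does not have the necessary information to save the "
                    "vocabulary for a slow tokenizer."
                )
            ...  # copy self.vocab_file into save_directory, as before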
diff --git a/src/transformers/models/herbert/tokenization_herbert_fast.py b/src/transformers/models/herbert/tokenization_herbert_fast.py
index beff50eaa86..2961d5c94cf 100644
--- a/src/transformers/models/herbert/tokenization_herbert_fast.py
+++ b/src/transformers/models/herbert/tokenization_herbert_fast.py
@@ -22,10 +22,7 @@ from .tokenization_herbert import HerbertTokenizer

 logger = logging.get_logger(__name__)

-VOCAB_FILES_NAMES = {
-    "vocab_file": "vocab.json",
-    "merges_file": "merges.txt",
-}
+VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}

 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
diff --git a/src/transformers/models/mbart50/tokenization_mbart50_fast.py b/src/transformers/models/mbart50/tokenization_mbart50_fast.py
index b3966f9c0b1..93f93d2423c 100644
--- a/src/transformers/models/mbart50/tokenization_mbart50_fast.py
+++ b/src/transformers/models/mbart50/tokenization_mbart50_fast.py
@@ -145,6 +145,7 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

         self.lang_code_to_id = {
             lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
@@ -258,6 +259,12 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
         return inputs

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/pegasus/tokenization_pegasus_fast.py b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
index c1c48c5cfbf..21c77594ead 100644
--- a/src/transformers/models/pegasus/tokenization_pegasus_fast.py
+++ b/src/transformers/models/pegasus/tokenization_pegasus_fast.py
@@ -148,6 +148,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
             **kwargs,
         )
         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def _special_token_mask(self, seq):
         all_special_ids = set(self.all_special_ids)  # call it once instead of inside list comp
@@ -192,6 +193,12 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
         return token_ids_0 + token_ids_1 + [self.eos_token_id]

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/reformer/tokenization_reformer_fast.py b/src/transformers/models/reformer/tokenization_reformer_fast.py
index 1e080478347..3fc8583c81a 100644
--- a/src/transformers/models/reformer/tokenization_reformer_fast.py
+++ b/src/transformers/models/reformer/tokenization_reformer_fast.py
@@ -104,8 +104,15 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/t5/tokenization_t5_fast.py b/src/transformers/models/t5/tokenization_t5_fast.py
index 3f972b006b7..faf8681071e 100644
--- a/src/transformers/models/t5/tokenization_t5_fast.py
+++ b/src/transformers/models/t5/tokenization_t5_fast.py
@@ -137,9 +137,16 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True
         self._extra_ids = extra_ids

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
index 6c7f63f6b02..3c686110fd0 100644
--- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
+++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta_fast.py
@@ -145,6 +145,7 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
         )

         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -198,6 +199,12 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
         return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
             return
diff --git a/src/transformers/models/xlnet/tokenization_xlnet_fast.py b/src/transformers/models/xlnet/tokenization_xlnet_fast.py
index d47827d9029..8b72c8def86 100644
--- a/src/transformers/models/xlnet/tokenization_xlnet_fast.py
+++ b/src/transformers/models/xlnet/tokenization_xlnet_fast.py
@@ -164,6 +164,7 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
         self.remove_space = remove_space
         self.keep_accents = keep_accents
         self.vocab_file = vocab_file
+        self.can_save_slow_tokenizer = False if not self.vocab_file else True

     def build_inputs_with_special_tokens(
         self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@@ -222,6 +223,12 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
         return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id

     def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
+        if not self.can_save_slow_tokenizer:
+            raise ValueError(
+                "Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
+                "tokenizer."
+            )
+
         if not os.path.isdir(save_directory):
             logger.error(f"Vocabulary path ({save_directory}) should be a directory")
             return
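The base-class change below is what actually stops `Trainer` from crashing: `save_pretrained` now consults the per-tokenizer flag before taking the legacy path. As a plain-function sketch of the resulting decision (the function itself is illustrative; the attribute names are from the patch):

    def should_save_slow(tokenizer, legacy_format):
        # Mirrors the new `save_slow` condition in save_pretrained: legacy
        # files are written only when requested (or unspecified), a slow
        # tokenizer class exists, and the required vocab files are available.
        return (
            (legacy_format is None or legacy_format is True)
            and tokenizer.slow_tokenizer_class is not None
            and tokenizer.can_save_slow_tokenizer
        )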
diff --git a/src/transformers/tokenization_utils_fast.py b/src/transformers/tokenization_utils_fast.py
index b37539bb4fc..ea1ed11f6ae 100644
--- a/src/transformers/tokenization_utils_fast.py
+++ b/src/transformers/tokenization_utils_fast.py
@@ -87,6 +87,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
     """

     slow_tokenizer_class: PreTrainedTokenizer = None
+    can_save_slow_tokenizer: bool = True

     def __init__(self, *args, **kwargs):
         tokenizer_object = kwargs.pop("tokenizer_object", None)
@@ -551,7 +552,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
                 "might consider leaving the legacy_format at `None` or setting it to `False`."
             )

-        save_slow = (legacy_format is None or legacy_format is True) and self.slow_tokenizer_class is not None
+        save_slow = (
+            (legacy_format is None or legacy_format is True)
+            and self.slow_tokenizer_class is not None
+            and self.can_save_slow_tokenizer
+        )
         save_fast = legacy_format is None or legacy_format is False

         if save_slow:
diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py
index 443fc402cfe..6e35828dc4c 100644
--- a/tests/test_tokenization_common.py
+++ b/tests/test_tokenization_common.py
@@ -38,6 +38,8 @@ from transformers import (
     PreTrainedTokenizerBase,
     PreTrainedTokenizerFast,
     SpecialTokensMixin,
+    Trainer,
+    TrainingArguments,
     is_tf_available,
     is_torch_available,
 )
@@ -56,6 +58,10 @@ from transformers.testing_utils import (
 from transformers.tokenization_utils import AddedToken


+if is_torch_available():
+    import torch.nn as nn
+
+
 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel

@@ -3389,6 +3395,27 @@ class TokenizerTesterMixin:
             )
         )

+    @require_torch
+    def test_saving_tokenizer_trainer(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                with tempfile.TemporaryDirectory() as tmp_dir:
+                    # Save the fast tokenizer files in a temporary directory
+                    tokenizer_old = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs, use_fast=True)
+                    tokenizer_old.save_pretrained(tmp_dir, legacy_format=False)  # save only fast version
+
+                    # Initialize toy model for the trainer
+                    model = nn.Module()
+
+                    # Load tokenizer from a folder without legacy files
+                    tokenizer = self.rust_tokenizer_class.from_pretrained(tmp_dir)
+                    training_args = TrainingArguments(output_dir=tmp_dir, do_train=True, no_cuda=True)
+                    trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)
+
+                    # Should not raise an error
+                    trainer.save_model(os.path.join(tmp_dir, "checkpoint"))
+                    self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))
+

 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):
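End to end, the new test reduces to the following flow; the directory names are illustrative, while the bare `nn.Module()` toy model is taken from the test itself:

    import os

    import torch.nn as nn
    from transformers import AutoTokenizer, Trainer, TrainingArguments

    # A directory that contains only the fast files (tokenizer.json).
    tokenizer = AutoTokenizer.from_pretrained("tok-fast-only")

    model = nn.Module()  # toy model, enough for save_model
    args = TrainingArguments(output_dir="out", do_train=True, no_cuda=True)
    trainer = Trainer(model=model, args=args, tokenizer=tokenizer)

    # Previously raised while trying to write the slow vocab files; with this
    # patch it succeeds and writes only the fast files.
    trainer.save_model(os.path.join("out", "checkpoint"))
    assert "tokenizer.json" in os.listdir(os.path.join("out", "checkpoint"))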