Fix tokenizer saving during training with Trainer (#12806)

* add test in trainer and test tokenizer saving wi
th trainer

* quality

* reverse trainer changes

* replace test in test_trainer by a test for all the tokenizers

* format

* add can_save_slow_tokenizer attribute to all tokenizers

* fix Herbert

* format

* Change comment in error

* add comments and a new assert

* Update src/transformers/models/albert/tokenization_albert_fast.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change ValueError barthez

* change ValueError BigBird

* change ValueError Camembert

* change ValueError Mbart50

* change ValueError Pegasus

* change ValueError ReFormer

* change ValueError T5

* change ValueError RoBERTa

* XLNET fast

* Update tests/test_tokenization_common.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* change `assert` into `self.assertIn`

* format

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
SaulLu 2021-09-01 16:32:56 +02:00 committed by GitHub
parent c1b20e42f5
commit c4d78f01de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 104 additions and 5 deletions

View File

@ -158,6 +158,7 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
self.remove_space = remove_space
self.keep_accents = keep_accents
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -216,6 +217,12 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -137,6 +137,7 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -187,6 +188,12 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -138,6 +138,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -227,6 +228,12 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -135,6 +135,7 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -186,6 +187,12 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -22,10 +22,7 @@ from .tokenization_herbert import HerbertTokenizer
logger = logging.get_logger(__name__)
VOCAB_FILES_NAMES = {
"vocab_file": "vocab.json",
"merges_file": "merges.txt",
}
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
PRETRAINED_VOCAB_FILES_MAP = {
"vocab_file": {

View File

@ -145,6 +145,7 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
self.lang_code_to_id = {
lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
@ -258,6 +259,12 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
return inputs
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -148,6 +148,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
**kwargs,
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def _special_token_mask(self, seq):
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
@ -192,6 +193,12 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
return token_ids_0 + token_ids_1 + [self.eos_token_id]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -104,8 +104,15 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -137,9 +137,16 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
self._extra_ids = extra_ids
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -145,6 +145,7 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
)
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -198,6 +199,12 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
return

View File

@ -164,6 +164,7 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
self.remove_space = remove_space
self.keep_accents = keep_accents
self.vocab_file = vocab_file
self.can_save_slow_tokenizer = False if not self.vocab_file else True
def build_inputs_with_special_tokens(
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
@ -222,6 +223,12 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
if not self.can_save_slow_tokenizer:
raise ValueError(
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
"tokenizer."
)
if not os.path.isdir(save_directory):
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
return

View File

@ -87,6 +87,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"""
slow_tokenizer_class: PreTrainedTokenizer = None
can_save_slow_tokenizer: bool = True
def __init__(self, *args, **kwargs):
tokenizer_object = kwargs.pop("tokenizer_object", None)
@ -551,7 +552,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
"might consider leaving the legacy_format at `None` or setting it to `False`."
)
save_slow = (legacy_format is None or legacy_format is True) and self.slow_tokenizer_class is not None
save_slow = (
(legacy_format is None or legacy_format is True)
and self.slow_tokenizer_class is not None
and self.can_save_slow_tokenizer
)
save_fast = legacy_format is None or legacy_format is False
if save_slow:

View File

@ -38,6 +38,8 @@ from transformers import (
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
SpecialTokensMixin,
Trainer,
TrainingArguments,
is_tf_available,
is_torch_available,
)
@ -56,6 +58,10 @@ from transformers.testing_utils import (
from transformers.tokenization_utils import AddedToken
if is_torch_available():
import torch.nn as nn
if TYPE_CHECKING:
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
@ -3389,6 +3395,27 @@ class TokenizerTesterMixin:
)
)
@require_torch
def test_saving_tokenizer_trainer(self):
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
with tempfile.TemporaryDirectory() as tmp_dir:
# Save the fast tokenizer files in a temporary directory
tokenizer_old = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs, use_fast=True)
tokenizer_old.save_pretrained(tmp_dir, legacy_format=False) # save only fast version
# Initialize toy model for the trainer
model = nn.Module()
# Load tokenizer from a folder without legacy files
tokenizer = self.rust_tokenizer_class.from_pretrained(tmp_dir)
training_args = TrainingArguments(output_dir=tmp_dir, do_train=True, no_cuda=True)
trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)
# Should not raise an error
trainer.save_model(os.path.join(tmp_dir, "checkpoint"))
self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))
@is_staging_test
class TokenizerPushToHubTester(unittest.TestCase):