mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix tokenizer saving during training with Trainer
(#12806)
* add test in trainer and test tokenizer saving wi th trainer * quality * reverse trainer changes * replace test in test_trainer by a test for all the tokenizers * format * add can_save_slow_tokenizer attribute to all tokenizers * fix Herbert * format * Change comment in error * add comments and a new assert * Update src/transformers/models/albert/tokenization_albert_fast.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * change ValueError barthez * change ValueError BigBird * change ValueError Camembert * change ValueError Mbart50 * change ValueError Pegasus * change ValueError ReFormer * change ValueError T5 * change ValueError RoBERTa * XLNET fast * Update tests/test_tokenization_common.py Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> * change `assert` into `self.assertIn` * format Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent
c1b20e42f5
commit
c4d78f01de
@ -158,6 +158,7 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
|
||||
self.remove_space = remove_space
|
||||
self.keep_accents = keep_accents
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
@ -216,6 +217,12 @@ class AlbertTokenizerFast(PreTrainedTokenizerFast):
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -137,6 +137,7 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
@ -187,6 +188,12 @@ class BarthezTokenizerFast(PreTrainedTokenizerFast):
|
||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -138,6 +138,7 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
@ -227,6 +228,12 @@ class BigBirdTokenizerFast(PreTrainedTokenizerFast):
|
||||
return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -135,6 +135,7 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
@ -186,6 +187,12 @@ class CamembertTokenizerFast(PreTrainedTokenizerFast):
|
||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -22,10 +22,7 @@ from .tokenization_herbert import HerbertTokenizer
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
VOCAB_FILES_NAMES = {
|
||||
"vocab_file": "vocab.json",
|
||||
"merges_file": "merges.txt",
|
||||
}
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "vocab.json", "merges_file": "merges.txt", "tokenizer_file": "tokenizer.json"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {
|
||||
"vocab_file": {
|
||||
|
@ -145,6 +145,7 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
self.lang_code_to_id = {
|
||||
lang_code: self.convert_tokens_to_ids(lang_code) for lang_code in FAIRSEQ_LANGUAGE_CODES
|
||||
@ -258,6 +259,12 @@ class MBart50TokenizerFast(PreTrainedTokenizerFast):
|
||||
return inputs
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -148,6 +148,7 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def _special_token_mask(self, seq):
|
||||
all_special_ids = set(self.all_special_ids) # call it once instead of inside list comp
|
||||
@ -192,6 +193,12 @@ class PegasusTokenizerFast(PreTrainedTokenizerFast):
|
||||
return token_ids_0 + token_ids_1 + [self.eos_token_id]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -104,8 +104,15 @@ class ReformerTokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -137,9 +137,16 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
self._extra_ids = extra_ids
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -145,6 +145,7 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
|
||||
)
|
||||
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
@ -198,6 +199,12 @@ class XLMRobertaTokenizerFast(PreTrainedTokenizerFast):
|
||||
return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0]
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory.")
|
||||
return
|
||||
|
@ -164,6 +164,7 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
|
||||
self.remove_space = remove_space
|
||||
self.keep_accents = keep_accents
|
||||
self.vocab_file = vocab_file
|
||||
self.can_save_slow_tokenizer = False if not self.vocab_file else True
|
||||
|
||||
def build_inputs_with_special_tokens(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
@ -222,6 +223,12 @@ class XLNetTokenizerFast(PreTrainedTokenizerFast):
|
||||
return len(token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1] + cls_segment_id
|
||||
|
||||
def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
if not self.can_save_slow_tokenizer:
|
||||
raise ValueError(
|
||||
"Your fast tokenizer does not have the necessary information to save the vocabulary for a slow "
|
||||
"tokenizer."
|
||||
)
|
||||
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
|
@ -87,6 +87,7 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
"""
|
||||
|
||||
slow_tokenizer_class: PreTrainedTokenizer = None
|
||||
can_save_slow_tokenizer: bool = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
tokenizer_object = kwargs.pop("tokenizer_object", None)
|
||||
@ -551,7 +552,11 @@ class PreTrainedTokenizerFast(PreTrainedTokenizerBase):
|
||||
"might consider leaving the legacy_format at `None` or setting it to `False`."
|
||||
)
|
||||
|
||||
save_slow = (legacy_format is None or legacy_format is True) and self.slow_tokenizer_class is not None
|
||||
save_slow = (
|
||||
(legacy_format is None or legacy_format is True)
|
||||
and self.slow_tokenizer_class is not None
|
||||
and self.can_save_slow_tokenizer
|
||||
)
|
||||
save_fast = legacy_format is None or legacy_format is False
|
||||
|
||||
if save_slow:
|
||||
|
@ -38,6 +38,8 @@ from transformers import (
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
SpecialTokensMixin,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
is_tf_available,
|
||||
is_torch_available,
|
||||
)
|
||||
@ -56,6 +58,10 @@ from transformers.testing_utils import (
|
||||
from transformers.tokenization_utils import AddedToken
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PretrainedConfig, PreTrainedModel, TFPreTrainedModel
|
||||
|
||||
@ -3389,6 +3395,27 @@ class TokenizerTesterMixin:
|
||||
)
|
||||
)
|
||||
|
||||
@require_torch
|
||||
def test_saving_tokenizer_trainer(self):
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save the fast tokenizer files in a temporary directory
|
||||
tokenizer_old = self.rust_tokenizer_class.from_pretrained(pretrained_name, **kwargs, use_fast=True)
|
||||
tokenizer_old.save_pretrained(tmp_dir, legacy_format=False) # save only fast version
|
||||
|
||||
# Initialize toy model for the trainer
|
||||
model = nn.Module()
|
||||
|
||||
# Load tokenizer from a folder without legacy files
|
||||
tokenizer = self.rust_tokenizer_class.from_pretrained(tmp_dir)
|
||||
training_args = TrainingArguments(output_dir=tmp_dir, do_train=True, no_cuda=True)
|
||||
trainer = Trainer(model=model, args=training_args, tokenizer=tokenizer)
|
||||
|
||||
# Should not raise an error
|
||||
trainer.save_model(os.path.join(tmp_dir, "checkpoint"))
|
||||
self.assertIn("tokenizer.json", os.listdir(os.path.join(tmp_dir, "checkpoint")))
|
||||
|
||||
|
||||
@is_staging_test
|
||||
class TokenizerPushToHubTester(unittest.TestCase):
|
||||
|
Loading…
Reference in New Issue
Block a user