From ade7371a41145c53fe90f66da7e5606f96068e98 Mon Sep 17 00:00:00 2001
From: SaulLu <55560583+SaulLu@users.noreply.github.com>
Date: Thu, 27 Jan 2022 16:24:51 +0100
Subject: [PATCH] improve saving strategy of sentencepiece tokenizer (#15328)

* add new test
* add a feature to save the sentencepiece tokenizer model when the init file was deleted
* update marian
* update m2m_100
* fix marian
* update speech to text
* override test for layoutxlm
* fix saving bartpho
* remove hardcoded values bartpho
* special token string version
* finish bartpho
* override layoutxlm test
* add mbart
* move special tokens list
* format
* Revert "format"

This reverts commit 37a40df37903a932c2f951cbd33acb684246bae7.

* simplify list of string of special tokens
* Re-write `self.fairseq_tokens_to_ids ` initialization logic with special tokens

Co-authored-by: Sylvain Gugger
Co-authored-by: Sylvain Gugger
---
 .../models/albert/tokenization_albert.py | 6 ++-
 .../models/bartpho/tokenization_bartpho.py | 29 ++++++++++---
 .../tokenization_bert_generation.py | 6 ++-
 .../models/big_bird/tokenization_big_bird.py | 6 ++-
 .../camembert/tokenization_camembert.py | 6 ++-
 .../models/fnet/tokenization_fnet.py | 6 ++-
 .../layoutxlm/tokenization_layoutxlm.py | 6 ++-
 .../models/m2m_100/tokenization_m2m_100.py | 7 ++-
 .../models/marian/tokenization_marian.py | 43 +++++++++++++------
 .../models/mbart/tokenization_mbart.py | 6 ++-
 .../models/mbart50/tokenization_mbart50.py | 6 ++-
 .../models/pegasus/tokenization_pegasus.py | 6 ++-
 .../models/reformer/tokenization_reformer.py | 6 ++-
 .../tokenization_speech_to_text.py | 8 +++-
 src/transformers/models/t5/tokenization_t5.py | 7 ++-
 .../tokenization_xlm_prophetnet.py | 6 ++-
 .../xlm_roberta/tokenization_xlm_roberta.py | 6 ++-
 .../models/xlnet/tokenization_xlnet.py | 6 ++-
 tests/test_tokenization_common.py | 27 ++++++++++++
 tests/test_tokenization_layoutxlm.py | 38 ++++++++++++++++
 tests/test_tokenization_mbart.py | 1 +
 21 files changed, 202 insertions(+), 36 deletions(-)

diff --git a/src/transformers/models/albert/tokenization_albert.py b/src/transformers/models/albert/tokenization_albert.py
index ba873afbaba..cfcfcd9daa1 100644
--- a/src/transformers/models/albert/tokenization_albert.py
+++ b/src/transformers/models/albert/tokenization_albert.py
@@ -343,7 +343,11 @@ class AlbertTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
diff --git a/src/transformers/models/bartpho/tokenization_bartpho.py b/src/transformers/models/bartpho/tokenization_bartpho.py
index 0bc17876af3..b12a962eae3 100644
--- a/src/transformers/models/bartpho/tokenization_bartpho.py
+++ b/src/transformers/models/bartpho/tokenization_bartpho.py
@@ -157,12 +157,20 @@ class BartphoTokenizer(PreTrainedTokenizer):
         self.sp_model.Load(str(vocab_file))
 
         # Load the reduced vocab
-        self.fairseq_tokens_to_ids = {"<s>": 0, "<pad>": 1, "</s>": 2, "<unk>": 3}
+
+        # Keep order of special tokens for backward compatibility
+        self.fairseq_tokens_to_ids = {}
+        cnt = 0
+        for token in [bos_token, pad_token, eos_token, unk_token, sep_token, cls_token]:
+            if str(token) not in self.fairseq_tokens_to_ids:
+                self.fairseq_tokens_to_ids[str(token)] = cnt
+                cnt += 1
         with open(monolingual_vocab_file, "r", encoding="utf-8") as f:
             for line in f.readlines():
                 token = line.strip().split()[0]
                 self.fairseq_tokens_to_ids[token] = len(self.fairseq_tokens_to_ids)
-        self.fairseq_tokens_to_ids["<mask>"] = len(self.fairseq_tokens_to_ids)
+        if str(mask_token) not in self.fairseq_tokens_to_ids:
+            self.fairseq_tokens_to_ids[str(mask_token)] = len(self.fairseq_tokens_to_ids)
 
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
 
@@ -278,7 +286,7 @@ class BartphoTokenizer(PreTrainedTokenizer):
         if token in self.fairseq_tokens_to_ids:
             return self.fairseq_tokens_to_ids[token]
         else:
-            return self.fairseq_tokens_to_ids["<unk>"]
+            return self.unk_token_id
 
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
@@ -301,10 +309,21 @@ class BartphoTokenizer(PreTrainedTokenizer):
             (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["monolingual_vocab_file"],
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
-        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(out_monolingual_vocab_file):
+        if os.path.abspath(self.monolingual_vocab_file) != os.path.abspath(
+            out_monolingual_vocab_file
+        ) and os.path.isfile(self.monolingual_vocab_file):
             copyfile(self.monolingual_vocab_file, out_monolingual_vocab_file)
+        elif not os.path.isfile(self.monolingual_vocab_file):
+            with open(out_monolingual_vocab_file, "w", encoding="utf-8") as fp:
+                for token in self.fairseq_tokens_to_ids:
+                    if token not in self.all_special_tokens:
+                        fp.write(f"{str(token)} \n")
 
         return out_vocab_file, out_monolingual_vocab_file
 
diff --git a/src/transformers/models/bert_generation/tokenization_bert_generation.py b/src/transformers/models/bert_generation/tokenization_bert_generation.py
index 66da1f0658c..e0e6a7ccb1c 100644
--- a/src/transformers/models/bert_generation/tokenization_bert_generation.py
+++ b/src/transformers/models/bert_generation/tokenization_bert_generation.py
@@ -160,7 +160,11 @@ class BertGenerationTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
+        if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file):
             copyfile(self.vocab_file, out_vocab_file)
+        elif not os.path.isfile(self.vocab_file):
+            with open(out_vocab_file, "wb") as fi:
+                content_spiece_model = self.sp_model.serialized_model_proto()
+                fi.write(content_spiece_model)
 
         return (out_vocab_file,)
 
diff --git a/src/transformers/models/big_bird/tokenization_big_bird.py b/src/transformers/models/big_bird/tokenization_big_bird.py
index 536fdc0b1d7..19f507f92bf 100644
--- a/src/transformers/models/big_bird/tokenization_big_bird.py
+++ b/src/transformers/models/big_bird/tokenization_big_bird.py
@@ -189,8 +189,12 @@ class BigBirdTokenizer(PreTrainedTokenizer):
             save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"]
         )
 
-        if 
os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/camembert/tokenization_camembert.py b/src/transformers/models/camembert/tokenization_camembert.py index 2d21c2e5bb5..60394148053 100644 --- a/src/transformers/models/camembert/tokenization_camembert.py +++ b/src/transformers/models/camembert/tokenization_camembert.py @@ -288,7 +288,11 @@ class CamembertTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/fnet/tokenization_fnet.py b/src/transformers/models/fnet/tokenization_fnet.py index 209e4f5229f..6143a9b08f2 100644 --- a/src/transformers/models/fnet/tokenization_fnet.py +++ b/src/transformers/models/fnet/tokenization_fnet.py @@ -305,7 +305,11 @@ class FNetTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py index 6928454f643..79608ac37a7 100644 --- a/src/transformers/models/layoutxlm/tokenization_layoutxlm.py +++ b/src/transformers/models/layoutxlm/tokenization_layoutxlm.py @@ -331,8 +331,12 @@ class LayoutXLMTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/m2m_100/tokenization_m2m_100.py b/src/transformers/models/m2m_100/tokenization_m2m_100.py index 4e54dfc12bc..e8f79101988 100644 --- a/src/transformers/models/m2m_100/tokenization_m2m_100.py +++ b/src/transformers/models/m2m_100/tokenization_m2m_100.py @@ -13,6 +13,7 @@ # limitations under the License. 
"""Tokenization classes for M2M100.""" import json +import os from contextlib import contextmanager from pathlib import Path from shutil import copyfile @@ -312,8 +313,12 @@ class M2M100Tokenizer(PreTrainedTokenizer): save_json(self.encoder, vocab_save_path) - if not spm_save_path.exists(): + if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file): copyfile(self.spm_file, spm_save_path) + elif not os.path.isfile(self.spm_file): + with open(spm_save_path, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (str(vocab_save_path), str(spm_save_path)) diff --git a/src/transformers/models/marian/tokenization_marian.py b/src/transformers/models/marian/tokenization_marian.py index 487f96ad791..1526ddaea8a 100644 --- a/src/transformers/models/marian/tokenization_marian.py +++ b/src/transformers/models/marian/tokenization_marian.py @@ -11,8 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import json +import os import re import warnings from contextlib import contextmanager @@ -23,8 +23,11 @@ from typing import Any, Dict, List, Optional, Tuple, Union import sentencepiece from ...tokenization_utils import PreTrainedTokenizer +from ...utils import logging +logger = logging.get_logger(__name__) + VOCAB_FILES_NAMES = { "source_spm": "source.spm", "target_spm": "target.spm", @@ -277,21 +280,35 @@ class MarianTokenizer(PreTrainedTokenizer): return len(self.encoder) def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]: - save_dir = Path(save_directory) - assert save_dir.is_dir(), f"{save_directory} should be a directory" - save_json( - self.encoder, - save_dir / ((filename_prefix + "-" if filename_prefix else "") + self.vocab_files_names["vocab"]), + if not os.path.isdir(save_directory): + logger.error(f"Vocabulary path ({save_directory}) should be a directory") + return + saved_files = [] + out_vocab_file = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab"] ) - for orig, f in zip(["source.spm", "target.spm"], self.spm_files): - dest_path = save_dir / ((filename_prefix + "-" if filename_prefix else "") + Path(f).name) - if not dest_path.exists(): - copyfile(f, save_dir / orig) + save_json(self.encoder, out_vocab_file) + saved_files.append(out_vocab_file) - return tuple( - save_dir / ((filename_prefix + "-" if filename_prefix else "") + f) for f in self.vocab_files_names - ) + for spm_save_filename, spm_orig_path, spm_model in zip( + [VOCAB_FILES_NAMES["source_spm"], VOCAB_FILES_NAMES["target_spm"]], + self.spm_files, + [self.spm_source, self.spm_target], + ): + spm_save_path = os.path.join( + save_directory, (filename_prefix + "-" if filename_prefix else "") + spm_save_filename + ) + if os.path.abspath(spm_orig_path) != os.path.abspath(spm_save_path) and os.path.isfile(spm_orig_path): + copyfile(spm_orig_path, spm_save_path) + saved_files.append(spm_save_path) + elif not os.path.isfile(spm_orig_path): + with open(spm_save_path, "wb") as fi: + content_spiece_model = spm_model.serialized_model_proto() + fi.write(content_spiece_model) + saved_files.append(spm_save_path) + + return tuple(saved_files) def get_vocab(self) -> Dict: vocab = self.encoder.copy() diff --git a/src/transformers/models/mbart/tokenization_mbart.py 
b/src/transformers/models/mbart/tokenization_mbart.py index 0ddae116de5..e6d3ff43379 100644 --- a/src/transformers/models/mbart/tokenization_mbart.py +++ b/src/transformers/models/mbart/tokenization_mbart.py @@ -315,8 +315,12 @@ class MBartTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/mbart50/tokenization_mbart50.py b/src/transformers/models/mbart50/tokenization_mbart50.py index 1282e4774ab..c7e53c61495 100644 --- a/src/transformers/models/mbart50/tokenization_mbart50.py +++ b/src/transformers/models/mbart50/tokenization_mbart50.py @@ -245,8 +245,12 @@ class MBart50Tokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/pegasus/tokenization_pegasus.py b/src/transformers/models/pegasus/tokenization_pegasus.py index 6b147ea05ac..2cc5511fc4d 100644 --- a/src/transformers/models/pegasus/tokenization_pegasus.py +++ b/src/transformers/models/pegasus/tokenization_pegasus.py @@ -285,7 +285,11 @@ class PegasusTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/reformer/tokenization_reformer.py b/src/transformers/models/reformer/tokenization_reformer.py index 83377622f86..8c75dda15e7 100644 --- a/src/transformers/models/reformer/tokenization_reformer.py +++ b/src/transformers/models/reformer/tokenization_reformer.py @@ -167,7 +167,11 @@ class ReformerTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git 
a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py index cf40fa713d4..7d77c945ced 100644 --- a/src/transformers/models/speech_to_text/tokenization_speech_to_text.py +++ b/src/transformers/models/speech_to_text/tokenization_speech_to_text.py @@ -13,8 +13,8 @@ # See the License for the specific language governing permissions and # limitations under the License. """Tokenization classes for Speech2Text.""" - import json +import os from pathlib import Path from shutil import copyfile from typing import Any, Dict, List, Optional, Tuple, Union @@ -260,8 +260,12 @@ class Speech2TextTokenizer(PreTrainedTokenizer): save_json(self.encoder, vocab_save_path) - if not spm_save_path.exists(): + if os.path.abspath(self.spm_file) != os.path.abspath(spm_save_path) and os.path.isfile(self.spm_file): copyfile(self.spm_file, spm_save_path) + elif not os.path.isfile(self.spm_file): + with open(spm_save_path, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (str(vocab_save_path), str(spm_save_path)) diff --git a/src/transformers/models/t5/tokenization_t5.py b/src/transformers/models/t5/tokenization_t5.py index bb2fcb2e41d..a356aa70c18 100644 --- a/src/transformers/models/t5/tokenization_t5.py +++ b/src/transformers/models/t5/tokenization_t5.py @@ -303,8 +303,11 @@ class T5Tokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) - logger.info(f"Copy vocab file to {out_vocab_file}") + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py b/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py index 9004433acf1..48f68238f12 100644 --- a/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py +++ b/src/transformers/models/xlm_prophetnet/tokenization_xlm_prophetnet.py @@ -302,8 +302,12 @@ class XLMProphetNetTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py index 824b8279e9b..072933a12ea 100644 --- a/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py +++ b/src/transformers/models/xlm_roberta/tokenization_xlm_roberta.py @@ -310,7 +310,11 @@ class XLMRobertaTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != 
os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/src/transformers/models/xlnet/tokenization_xlnet.py b/src/transformers/models/xlnet/tokenization_xlnet.py index 7f0c28d0c07..0dc7d9e72c1 100644 --- a/src/transformers/models/xlnet/tokenization_xlnet.py +++ b/src/transformers/models/xlnet/tokenization_xlnet.py @@ -342,7 +342,11 @@ class XLNetTokenizer(PreTrainedTokenizer): save_directory, (filename_prefix + "-" if filename_prefix else "") + VOCAB_FILES_NAMES["vocab_file"] ) - if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file): + if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file) and os.path.isfile(self.vocab_file): copyfile(self.vocab_file, out_vocab_file) + elif not os.path.isfile(self.vocab_file): + with open(out_vocab_file, "wb") as fi: + content_spiece_model = self.sp_model.serialized_model_proto() + fi.write(content_spiece_model) return (out_vocab_file,) diff --git a/tests/test_tokenization_common.py b/tests/test_tokenization_common.py index 3a6da1e1228..bee7ee72092 100644 --- a/tests/test_tokenization_common.py +++ b/tests/test_tokenization_common.py @@ -394,6 +394,33 @@ class TokenizerTesterMixin: self.assertEqual(tokenizer_new.sp_model_kwargs, sp_model_kwargs) self.check_subword_sampling(tokenizer_new) + def test_save_sentencepiece_tokenizer(self) -> None: + if not self.test_sentencepiece or not self.test_slow_tokenizer: + return + # We want to verify that we will be able to save the tokenizer even if the original files that were used to + # build the tokenizer have been deleted in the meantime. + text = "This is text to test the tokenizer." 
+ + tokenizer_slow_1 = self.get_tokenizer() + encoding_tokenizer_slow_1 = tokenizer_slow_1(text) + + tmpdirname_1 = tempfile.mkdtemp() + tmpdirname_2 = tempfile.mkdtemp() + + tokenizer_slow_1.save_pretrained(tmpdirname_1) + tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1) + encoding_tokenizer_slow_2 = tokenizer_slow_2(text) + + shutil.rmtree(tmpdirname_1) + tokenizer_slow_2.save_pretrained(tmpdirname_2) + + tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2) + encoding_tokenizer_slow_3 = tokenizer_slow_3(text) + shutil.rmtree(tmpdirname_2) + + self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2) + self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3) + def test_model_input_names_signature(self): accepted_model_main_input_names = [ "input_ids", # nlp models diff --git a/tests/test_tokenization_layoutxlm.py b/tests/test_tokenization_layoutxlm.py index a0478971c60..b0643cfe685 100644 --- a/tests/test_tokenization_layoutxlm.py +++ b/tests/test_tokenization_layoutxlm.py @@ -99,6 +99,44 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase): output_text = "unwanted, running" return input_text, output_text + # override test in `test_tokenization_common.py` because of the required input format of the `__call__`` method of + # this tokenizer + def test_save_sentencepiece_tokenizer(self) -> None: + if not self.test_sentencepiece or not self.test_slow_tokenizer: + return + # We want to verify that we will be able to save the tokenizer even if the original files that were used to + # build the tokenizer have been deleted in the meantime. + words, boxes = self.get_words_and_boxes() + + tokenizer_slow_1 = self.get_tokenizer() + encoding_tokenizer_slow_1 = tokenizer_slow_1( + words, + boxes=boxes, + ) + + tmpdirname_1 = tempfile.mkdtemp() + tmpdirname_2 = tempfile.mkdtemp() + + tokenizer_slow_1.save_pretrained(tmpdirname_1) + tokenizer_slow_2 = self.tokenizer_class.from_pretrained(tmpdirname_1) + encoding_tokenizer_slow_2 = tokenizer_slow_2( + words, + boxes=boxes, + ) + + shutil.rmtree(tmpdirname_1) + tokenizer_slow_2.save_pretrained(tmpdirname_2) + + tokenizer_slow_3 = self.tokenizer_class.from_pretrained(tmpdirname_2) + encoding_tokenizer_slow_3 = tokenizer_slow_3( + words, + boxes=boxes, + ) + shutil.rmtree(tmpdirname_2) + + self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2) + self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3) + @slow def test_sequence_builders(self): tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base") diff --git a/tests/test_tokenization_mbart.py b/tests/test_tokenization_mbart.py index c2c50c95322..aa4868e8589 100644 --- a/tests/test_tokenization_mbart.py +++ b/tests/test_tokenization_mbart.py @@ -39,6 +39,7 @@ class MBartTokenizationTest(TokenizerTesterMixin, unittest.TestCase): tokenizer_class = MBartTokenizer rust_tokenizer_class = MBartTokenizerFast test_rust_tokenizer = True + test_sentencepiece = True def setUp(self): super().setUp()
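Editor's note: every `save_vocabulary` change in this patch applies the same fallback: copy the original SentencePiece file when it is still on disk, otherwise rewrite it from the protobuf held by the in-memory `sp_model`. The snippet below is a minimal standalone sketch of that pattern; the helper name `save_spm_model` and its signature are illustrative only and not part of the transformers API.

# Minimal sketch of the saving strategy used by the patch above.
# `save_spm_model` is an illustrative helper, not a transformers function.
import os
from shutil import copyfile

import sentencepiece as spm


def save_spm_model(sp_model: spm.SentencePieceProcessor, original_path: str, out_path: str) -> str:
    if os.path.abspath(original_path) != os.path.abspath(out_path) and os.path.isfile(original_path):
        # The file the tokenizer was built from still exists: a plain copy is enough.
        copyfile(original_path, out_path)
    elif not os.path.isfile(original_path):
        # The original file is gone (e.g. its directory was deleted after loading):
        # fall back to the model proto kept in memory by the SentencePiece processor.
        with open(out_path, "wb") as fi:
            fi.write(sp_model.serialized_model_proto())
    return out_path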
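Editor's note: the new `test_save_sentencepiece_tokenizer` covers a save, delete, re-save round trip through the test mixin. Below is a rough standalone illustration of the same scenario, assuming network access and using `AlbertTokenizer` with the public `albert-base-v2` checkpoint purely as an example; any slow SentencePiece-based tokenizer touched by this patch should behave the same way.

# Round-trip sketch: the second save must succeed even though the directory
# holding the .spm file that tokenizer_2 was loaded from has been deleted.
import shutil
import tempfile

from transformers import AlbertTokenizer

text = "This is text to test the tokenizer."

tokenizer_1 = AlbertTokenizer.from_pretrained("albert-base-v2")
encoding_1 = tokenizer_1(text)

dir_1, dir_2 = tempfile.mkdtemp(), tempfile.mkdtemp()

tokenizer_1.save_pretrained(dir_1)
tokenizer_2 = AlbertTokenizer.from_pretrained(dir_1)

# Remove the first directory, then save again: before this patch, saving would
# typically fail here because copyfile() could no longer find the original vocab file.
shutil.rmtree(dir_1)
tokenizer_2.save_pretrained(dir_2)

tokenizer_3 = AlbertTokenizer.from_pretrained(dir_2)
assert tokenizer_2(text) == encoding_1
assert tokenizer_3(text) == encoding_1
shutil.rmtree(dir_2)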