Enable added tokens (#11325)

* Fix tests

* Reorganize

* Update tests/test_modeling_mobilebert.py

* Remove unnecessary addition
Lysandre Debut 2021-05-04 14:13:57 +02:00 committed by GitHub
parent c40c7e213b
commit 09b0bcfea9
8 changed files with 136 additions and 26 deletions

src/transformers/models/herbert/tokenization_herbert.py

@@ -58,19 +58,37 @@ class HerbertTokenizer(XLMTokenizer):
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES

-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        do_lowercase_and_remove_accent=False,
+        **kwargs
+    ):
-        kwargs["cls_token"] = "<s>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["pad_token"] = "<pad>"
-        kwargs["mask_token"] = "<mask>"
-        kwargs["sep_token"] = "</s>"
-        kwargs["do_lowercase_and_remove_accent"] = False
-        kwargs["additional_special_tokens"] = []
-        super().__init__(*args, **kwargs)
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=None,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
+            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
+            **kwargs,
+        )

         self.bert_pre_tokenizer = BasicTokenizer(
-            do_lower_case=False, never_split=self.all_special_tokens, tokenize_chinese_chars=False, strip_accents=False
+            do_lower_case=False,
+            never_split=self.all_special_tokens,
+            tokenize_chinese_chars=False,
+            strip_accents=False,
         )

     def _tokenize(self, text):

src/transformers/models/herbert/tokenization_herbert_fast.py

@@ -65,18 +65,28 @@ class HerbertTokenizerFast(PreTrainedTokenizerFast):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     slow_tokenizer_class = HerbertTokenizer

-    def __init__(self, vocab_file, merges_file, tokenizer_file=None, **kwargs):
-        kwargs["cls_token"] = "<s>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["pad_token"] = "<pad>"
-        kwargs["mask_token"] = "<mask>"
-        kwargs["sep_token"] = "</s>"
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        **kwargs
+    ):
         super().__init__(
             vocab_file,
             merges_file,
             tokenizer_file=tokenizer_file,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
             **kwargs,
         )
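As a usage note for the two HerBERT changes above, here is a minimal sketch of what the pass-through kwargs enable; it is illustrative only and assumes the public allegro/herbert-base-cased checkpoint. User-supplied additional_special_tokens now reach the base class instead of being clobbered by the old hard-coded kwargs.

from transformers import AddedToken, HerbertTokenizerFast

# Illustrative sketch: the custom token survives initialization and is encoded as a single id.
tokenizer = HerbertTokenizerFast.from_pretrained(
    "allegro/herbert-base-cased",
    additional_special_tokens=[AddedToken("<special>", lstrip=True)],
)
ids = tokenizer.encode("Hey this is a <special> token")
assert tokenizer.convert_tokens_to_ids("<special>") in ids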

src/transformers/models/mbart/tokenization_mbart.py

@@ -97,8 +97,17 @@ class MBartTokenizer(XLMRobertaTokenizer):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []

-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )

         self.sp_model_size = len(self.sp_model)
         self.lang_code_to_id = {

@@ -111,6 +120,9 @@ class MBartTokenizer(XLMRobertaTokenizer):
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
+        if additional_special_tokens is not None:
+            self._additional_special_tokens.extend(additional_special_tokens)

         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
         self.tgt_lang = tgt_lang

src/transformers/models/mbart/tokenization_mbart_fast.py

@@ -112,10 +112,24 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []

-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )

-        self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
+        _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
+        if additional_special_tokens is not None:
+            _additional_special_tokens.extend(additional_special_tokens)
+        self.add_special_tokens({"additional_special_tokens": _additional_special_tokens})

         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
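For the two MBart changes, an illustrative sketch (assuming the public facebook/mbart-large-cc25 checkpoint): user tokens are now appended to the FAIRSEQ language codes instead of being dropped when the codes are registered.

from transformers import AddedToken, MBartTokenizerFast

# Illustrative sketch: language codes stay registered and the user token is added on top of them.
tokenizer = MBartTokenizerFast.from_pretrained(
    "facebook/mbart-large-cc25",
    additional_special_tokens=[AddedToken("<special>", lstrip=True)],
)
assert "ro_RO" in tokenizer.additional_special_tokens      # language code kept
assert "<special>" in tokenizer.additional_special_tokens  # user token kept as well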

src/transformers/models/t5/tokenization_t5.py

@@ -107,7 +107,7 @@ class T5Tokenizer(PreTrainedTokenizer):
             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
         elif extra_ids > 0 and additional_special_tokens is not None:
             # Check that we have the right number of extra_id special tokens
-            extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
+            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
             if extra_tokens != extra_ids:
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "

src/transformers/models/t5/tokenization_t5_fast.py

@@ -118,7 +118,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
         elif extra_ids > 0 and additional_special_tokens is not None:
             # Check that we have the right number of extra special tokens
-            extra_tokens = len(set(filter(lambda x: bool("extra_id_" in x), additional_special_tokens)))
+            extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens)))
             if extra_tokens != extra_ids:
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
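The str(x) cast in both T5 tokenizers matters because additional_special_tokens may now contain AddedToken objects as well as plain strings, and the old substring check only worked on strings. A hedged sketch (assuming the public t5-small checkpoint), mirroring the new test added below:

from transformers import AddedToken, T5TokenizerFast

# Illustrative sketch: mixing an AddedToken entry with the 100 default sentinel tokens
# no longer trips the extra_id consistency check.
extra = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
tokenizer = T5TokenizerFast.from_pretrained("t5-small", additional_special_tokens=extra)
ids = tokenizer.encode("Hey this is a <special> token")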

tests/test_tokenization_common.py

@@ -2872,6 +2872,34 @@ class TokenizerTesterMixin:
                 for key in python_output:
                     self.assertEqual(python_output[key], rust_output[key])

+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                added_tokens = [AddedToken("<special>", lstrip=True)]
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+                )
+                tokenizer_p = self.tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+
+                p_output = tokenizer_p.encode("Hey this is a <special> token")
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+                cr_output = tokenizer_cr.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertEqual(p_output, r_output)
+                self.assertEqual(cr_output, r_output)
+                self.assertTrue(special_token_id in p_output)
+                self.assertTrue(special_token_id in r_output)
+                self.assertTrue(special_token_id in cr_output)

 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):

tests/test_tokenization_t5.py

@@ -16,7 +16,7 @@
 import unittest

-from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
+from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
 from transformers.file_utils import cached_property, is_tf_available, is_torch_available
 from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers

@@ -246,3 +246,31 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         slow_text = self.t5_base_tokenizer.decode(fast_ids)
         self.assertEqual(tgt_text, fast_text)
         self.assertEqual(tgt_text, slow_text)
+
+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+                )
+                tokenizer_p = self.tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+
+                p_output = tokenizer_p.encode("Hey this is a <special> token")
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+                cr_output = tokenizer_cr.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertEqual(p_output, r_output)
+                self.assertEqual(cr_output, r_output)
+                self.assertTrue(special_token_id in p_output)
+                self.assertTrue(special_token_id in r_output)
+                self.assertTrue(special_token_id in cr_output)
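The new test is picked up through the per-model tokenization test modules; as one illustrative way to run it (assuming a standard development install of the repository):

import pytest

# Illustrative: equivalent to `pytest tests/test_tokenization_t5.py -k test_special_tokens_initialization`
pytest.main(["tests/test_tokenization_t5.py", "-k", "test_special_tokens_initialization"])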