Mirror of https://github.com/huggingface/transformers.git
Enable added tokens (#11325)
* Fix tests
* Reorganize
* Update tests/test_modeling_mobilebert.py
* Remove unnecessary addition
parent c40c7e213b
commit 09b0bcfea9
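Note (not part of the commit itself): the change lets additional_special_tokens passed at load time, including AddedToken objects, reach these tokenizers instead of being overwritten. A minimal usage sketch along the lines of the new tests; the checkpoint name is a placeholder:

from transformers import AddedToken, T5TokenizerFast

# Placeholder checkpoint. T5 also expects its 100 <extra_id_*> sentinels to be
# listed whenever additional_special_tokens is given, as the new T5 test does.
added = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
tokenizer = T5TokenizerFast.from_pretrained("t5-small", additional_special_tokens=added)
ids = tokenizer.encode("Hey this is a <special> token")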
@@ -58,19 +58,37 @@ class HerbertTokenizer(XLMTokenizer):
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
-    def __init__(self, *args, **kwargs):
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        do_lowercase_and_remove_accent=False,
+        **kwargs
+    ):
 
-        kwargs["cls_token"] = "<s>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["pad_token"] = "<pad>"
-        kwargs["mask_token"] = "<mask>"
-        kwargs["sep_token"] = "</s>"
-        kwargs["do_lowercase_and_remove_accent"] = False
-        kwargs["additional_special_tokens"] = []
-
-        super().__init__(*args, **kwargs)
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=None,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
+            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
+            **kwargs,
+        )
         self.bert_pre_tokenizer = BasicTokenizer(
-            do_lower_case=False, never_split=self.all_special_tokens, tokenize_chinese_chars=False, strip_accents=False
+            do_lower_case=False,
+            never_split=self.all_special_tokens,
+            tokenize_chinese_chars=False,
+            strip_accents=False,
         )
 
     def _tokenize(self, text):
@@ -65,18 +65,28 @@ class HerbertTokenizerFast(PreTrainedTokenizerFast):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     slow_tokenizer_class = HerbertTokenizer
 
-    def __init__(self, vocab_file, merges_file, tokenizer_file=None, **kwargs):
-
-        kwargs["cls_token"] = "<s>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["pad_token"] = "<pad>"
-        kwargs["mask_token"] = "<mask>"
-        kwargs["sep_token"] = "</s>"
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        **kwargs
+    ):
 
         super().__init__(
             vocab_file,
             merges_file,
             tokenizer_file=tokenizer_file,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
+            **kwargs,
         )
 
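Note: the Herbert hunks above move the hard-coded special tokens into the signature, so values supplied by the caller are no longer silently discarded. A hedged sketch, with an assumed checkpoint name:

from transformers import HerbertTokenizer

# Assumed checkpoint; previously __init__ forced additional_special_tokens to []
# so tokens passed here were dropped.
tokenizer = HerbertTokenizer.from_pretrained(
    "allegro/herbert-base-cased", additional_special_tokens=["<special>"]
)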
@@ -97,8 +97,17 @@ class MBartTokenizer(XLMRobertaTokenizer):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
 
         self.sp_model_size = len(self.sp_model)
         self.lang_code_to_id = {
@@ -111,6 +120,9 @@ class MBartTokenizer(XLMRobertaTokenizer):
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
 
+        if additional_special_tokens is not None:
+            self._additional_special_tokens.extend(additional_special_tokens)
+
         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
         self.tgt_lang = tgt_lang
@@ -112,10 +112,24 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
 
-        self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
+        _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
+
+        if additional_special_tokens is not None:
+            _additional_special_tokens.extend(additional_special_tokens)
+
+        self.add_special_tokens({"additional_special_tokens": _additional_special_tokens})
 
         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
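Note: for mBART, user-supplied additional_special_tokens are now merged with the FAIRSEQ language codes instead of being replaced by them. A hedged sketch, assuming the facebook/mbart-large-en-ro checkpoint and an illustrative "<domain>" token:

from transformers import MBartTokenizerFast

tokenizer = MBartTokenizerFast.from_pretrained(
    "facebook/mbart-large-en-ro",
    additional_special_tokens=["<domain>"],
)
# Expected after this change: "<domain>" is registered alongside en_XX, ro_RO, ...
print("<domain>" in tokenizer.additional_special_tokens)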
@@ -107,7 +107,7 @@ class T5Tokenizer(PreTrainedTokenizer):
             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
         elif extra_ids > 0 and additional_special_tokens is not None:
             # Check that we have the right number of extra_id special tokens
-            extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
+            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
             if extra_tokens != extra_ids:
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
@@ -118,7 +118,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
         elif extra_ids > 0 and additional_special_tokens is not None:
             # Check that we have the right number of extra special tokens
-            extra_tokens = len(set(filter(lambda x: bool("extra_id_" in x), additional_special_tokens)))
+            extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens)))
             if extra_tokens != extra_ids:
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
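Note: the str(x) cast in the two T5 hunks above is what lets the extra_id count check tolerate AddedToken entries; a bare "extra_id" in x assumes x is a string and would typically raise a TypeError for an AddedToken object. A hedged sketch:

from transformers import AddedToken

tokens = ["<extra_id_0>", AddedToken("<special>", lstrip=True)]
# Casting each entry to str makes the filter uniform for plain strings and
# AddedToken objects alike.
extra = [t for t in tokens if "extra_id" in str(t)]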
@@ -2872,6 +2872,34 @@ class TokenizerTesterMixin:
         for key in python_output:
             self.assertEqual(python_output[key], rust_output[key])
 
+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                added_tokens = [AddedToken("<special>", lstrip=True)]
+
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+                )
+                tokenizer_p = self.tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+
+                p_output = tokenizer_p.encode("Hey this is a <special> token")
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+                cr_output = tokenizer_cr.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertEqual(p_output, r_output)
+                self.assertEqual(cr_output, r_output)
+                self.assertTrue(special_token_id in p_output)
+                self.assertTrue(special_token_id in r_output)
+                self.assertTrue(special_token_id in cr_output)
+
 
 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):
@@ -16,7 +16,7 @@
 
 import unittest
 
-from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
+from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
 from transformers.file_utils import cached_property, is_tf_available, is_torch_available
 from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
 
@@ -246,3 +246,31 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         slow_text = self.t5_base_tokenizer.decode(fast_ids)
         self.assertEqual(tgt_text, fast_text)
         self.assertEqual(tgt_text, slow_text)
+
+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
+
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+                )
+                tokenizer_p = self.tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+
+                p_output = tokenizer_p.encode("Hey this is a <special> token")
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+                cr_output = tokenizer_cr.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertEqual(p_output, r_output)
+                self.assertEqual(cr_output, r_output)
+                self.assertTrue(special_token_id in p_output)
+                self.assertTrue(special_token_id in r_output)
+                self.assertTrue(special_token_id in cr_output)