Enable added tokens (#11325)

* Fix tests
* Reorganize
* Update tests/test_modeling_mobilebert.py
* Remove unnecessary addition

This commit is contained in:
parent c40c7e213b
commit 09b0bcfea9
@@ -58,19 +58,37 @@ class HerbertTokenizer(XLMTokenizer):
     pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
 
-    def __init__(self, *args, **kwargs):
-
-        kwargs["cls_token"] = "<s>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["pad_token"] = "<pad>"
-        kwargs["mask_token"] = "<mask>"
-        kwargs["sep_token"] = "</s>"
-        kwargs["do_lowercase_and_remove_accent"] = False
-        kwargs["additional_special_tokens"] = []
-
-        super().__init__(*args, **kwargs)
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        do_lowercase_and_remove_accent=False,
+        **kwargs
+    ):
+
+        super().__init__(
+            vocab_file,
+            merges_file,
+            tokenizer_file=None,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
+            do_lowercase_and_remove_accent=do_lowercase_and_remove_accent,
+            **kwargs,
+        )
 
         self.bert_pre_tokenizer = BasicTokenizer(
-            do_lower_case=False, never_split=self.all_special_tokens, tokenize_chinese_chars=False, strip_accents=False
+            do_lower_case=False,
+            never_split=self.all_special_tokens,
+            tokenize_chinese_chars=False,
+            strip_accents=False,
         )
 
     def _tokenize(self, text):
@@ -65,18 +65,28 @@ class HerbertTokenizerFast(PreTrainedTokenizerFast):
     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
     slow_tokenizer_class = HerbertTokenizer
 
-    def __init__(self, vocab_file, merges_file, tokenizer_file=None, **kwargs):
-
-        kwargs["cls_token"] = "<s>"
-        kwargs["unk_token"] = "<unk>"
-        kwargs["pad_token"] = "<pad>"
-        kwargs["mask_token"] = "<mask>"
-        kwargs["sep_token"] = "</s>"
+    def __init__(
+        self,
+        vocab_file,
+        merges_file,
+        tokenizer_file=None,
+        cls_token="<s>",
+        unk_token="<unk>",
+        pad_token="<pad>",
+        mask_token="<mask>",
+        sep_token="</s>",
+        **kwargs
+    ):
 
         super().__init__(
             vocab_file,
             merges_file,
             tokenizer_file=tokenizer_file,
+            cls_token=cls_token,
+            unk_token=unk_token,
+            pad_token=pad_token,
+            mask_token=mask_token,
+            sep_token=sep_token,
             **kwargs,
         )
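
Taken together, the two HerBERT hunks turn the hard-coded special tokens into ordinary `__init__` parameters, so user-supplied tokens are no longer discarded (the slow tokenizer previously forced `additional_special_tokens` to `[]`). A minimal usage sketch of what this enables; the checkpoint name is illustrative and not part of this diff:

    from transformers import AddedToken, HerbertTokenizer

    # Illustrative checkpoint; any HerBERT checkpoint would do.
    tok = HerbertTokenizer.from_pretrained(
        "allegro/herbert-base-cased",
        additional_special_tokens=[AddedToken("<special>", lstrip=True)],
    )
    ids = tok.encode("Hey this is a <special> token")
    special_id = tok.encode("<special>", add_special_tokens=False)[0]
    assert special_id in ids  # the added token is kept as a single piece
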
@@ -97,8 +97,17 @@ class MBartTokenizer(XLMRobertaTokenizer):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
 
         self.sp_model_size = len(self.sp_model)
         self.lang_code_to_id = {
@@ -111,6 +120,9 @@ class MBartTokenizer(XLMRobertaTokenizer):
         self.fairseq_ids_to_tokens = {v: k for k, v in self.fairseq_tokens_to_ids.items()}
         self._additional_special_tokens = list(self.lang_code_to_id.keys())
 
+        if additional_special_tokens is not None:
+            self._additional_special_tokens.extend(additional_special_tokens)
+
         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code_id = self.lang_code_to_id[self._src_lang]
         self.tgt_lang = tgt_lang
@@ -112,10 +112,24 @@ class MBartTokenizerFast(XLMRobertaTokenizerFast):
     prefix_tokens: List[int] = []
     suffix_tokens: List[int] = []
 
-    def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
-        super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
+    def __init__(
+        self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, additional_special_tokens=None, **kwargs
+    ):
+        super().__init__(
+            *args,
+            tokenizer_file=tokenizer_file,
+            src_lang=src_lang,
+            tgt_lang=tgt_lang,
+            additional_special_tokens=additional_special_tokens,
+            **kwargs,
+        )
 
-        self.add_special_tokens({"additional_special_tokens": FAIRSEQ_LANGUAGE_CODES})
+        _additional_special_tokens = FAIRSEQ_LANGUAGE_CODES.copy()
+
+        if additional_special_tokens is not None:
+            _additional_special_tokens.extend(additional_special_tokens)
+
+        self.add_special_tokens({"additional_special_tokens": _additional_special_tokens})
 
         self._src_lang = src_lang if src_lang is not None else "en_XX"
         self.cur_lang_code = self.convert_tokens_to_ids(self._src_lang)
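
For MBart, the fast tokenizer used to overwrite `additional_special_tokens` with the FAIRSEQ language codes; it now copies the language-code list and appends any user-supplied tokens, mirroring the slow tokenizer. A sketch of the resulting behaviour, with an illustrative checkpoint name:

    from transformers import MBartTokenizerFast

    tok = MBartTokenizerFast.from_pretrained(
        "facebook/mbart-large-cc25", additional_special_tokens=["<special>"]
    )
    assert "en_XX" in tok.additional_special_tokens      # language codes are kept
    assert "<special>" in tok.additional_special_tokens  # user tokens are appended
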
@@ -107,7 +107,7 @@ class T5Tokenizer(PreTrainedTokenizer):
             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
         elif extra_ids > 0 and additional_special_tokens is not None:
             # Check that we have the right number of extra_id special tokens
-            extra_tokens = len(set(filter(lambda x: bool("extra_id" in x), additional_special_tokens)))
+            extra_tokens = len(set(filter(lambda x: bool("extra_id" in str(x)), additional_special_tokens)))
             if extra_tokens != extra_ids:
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
@@ -118,7 +118,7 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
             additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]
         elif extra_ids > 0 and additional_special_tokens is not None:
             # Check that we have the right number of extra special tokens
-            extra_tokens = len(set(filter(lambda x: bool("extra_id_" in x), additional_special_tokens)))
+            extra_tokens = len(set(filter(lambda x: bool("extra_id_" in str(x)), additional_special_tokens)))
             if extra_tokens != extra_ids:
                 raise ValueError(
                     f"Both extra_ids ({extra_ids}) and additional_special_tokens ({additional_special_tokens}) are provided to T5Tokenizer. "
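
The T5 change in both tokenizers is small: the `extra_id` count check ran `"extra_id" in x`, which breaks when an entry is an `AddedToken` object rather than a plain string; wrapping the entry in `str()` accepts both kinds. A sketch of the pattern the new T5 test relies on (checkpoint name illustrative):

    from transformers import AddedToken, T5Tokenizer

    # The 100 "<extra_id_*>" strings satisfy the default extra_ids=100 count check;
    # the AddedToken entry passes through it because the test is run on str(token).
    added = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
    tok = T5Tokenizer.from_pretrained("t5-small", additional_special_tokens=added)
    ids = tok.encode("Hey this is a <special> token")
    assert tok.encode("<special>", add_special_tokens=False)[0] in ids
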
@@ -2872,6 +2872,34 @@ class TokenizerTesterMixin:
             for key in python_output:
                 self.assertEqual(python_output[key], rust_output[key])
 
+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                added_tokens = [AddedToken("<special>", lstrip=True)]
+
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+                )
+                tokenizer_p = self.tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+
+                p_output = tokenizer_p.encode("Hey this is a <special> token")
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+                cr_output = tokenizer_cr.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertEqual(p_output, r_output)
+                self.assertEqual(cr_output, r_output)
+                self.assertTrue(special_token_id in p_output)
+                self.assertTrue(special_token_id in r_output)
+                self.assertTrue(special_token_id in cr_output)
+
 
 @is_staging_test
 class TokenizerPushToHubTester(unittest.TestCase):
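
In this common test, `tokenizer_r` is the fast (Rust-backed) tokenizer, `tokenizer_cr` is the fast tokenizer rebuilt from the slow files via `from_slow=True`, and `tokenizer_p` is the Python (slow) tokenizer; the assertions check that all three produce the same ids and that `<special>` is encoded as a single token rather than being split.
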
@@ -16,7 +16,7 @@
 
 import unittest
 
-from transformers import SPIECE_UNDERLINE, BatchEncoding, T5Tokenizer, T5TokenizerFast
+from transformers import SPIECE_UNDERLINE, AddedToken, BatchEncoding, T5Tokenizer, T5TokenizerFast
 from transformers.file_utils import cached_property, is_tf_available, is_torch_available
 from transformers.testing_utils import get_tests_dir, require_sentencepiece, require_tokenizers
 
@@ -246,3 +246,31 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         slow_text = self.t5_base_tokenizer.decode(fast_ids)
         self.assertEqual(tgt_text, fast_text)
         self.assertEqual(tgt_text, slow_text)
+
+    def test_special_tokens_initialization(self):
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+
+                added_tokens = [f"<extra_id_{i}>" for i in range(100)] + [AddedToken("<special>", lstrip=True)]
+
+                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+                tokenizer_cr = self.rust_tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
+                )
+                tokenizer_p = self.tokenizer_class.from_pretrained(
+                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
+                )
+
+                p_output = tokenizer_p.encode("Hey this is a <special> token")
+                r_output = tokenizer_r.encode("Hey this is a <special> token")
+                cr_output = tokenizer_cr.encode("Hey this is a <special> token")
+
+                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]
+
+                self.assertEqual(p_output, r_output)
+                self.assertEqual(cr_output, r_output)
+                self.assertTrue(special_token_id in p_output)
+                self.assertTrue(special_token_id in r_output)
+                self.assertTrue(special_token_id in cr_output)