mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[split_special_tokens
] Add support for split_special_tokens
argument to encode (#25081)
* draft changes * update and add tests * styling for no * move test * path to usable model * update test * small update * update bertbased tokenizers * don'tuse kwargs for _tokenize * don'tuse kwargs for _tokenize * fix copies * update * update test for special tokenizers * fixup * skip two tests * remove pdb breakpiont() * wowo * rewrite custom tests * nits * revert chang in target keys * fix markup lm * update documentation of the argument
This commit is contained in:
parent
9d7afd2536
commit
30b3c46ff5
@ -238,10 +238,12 @@ class BertTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -177,10 +177,12 @@ class ConvBertTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -178,10 +178,12 @@ class RetriBertTokenizer(PreTrainedTokenizer):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -195,10 +195,12 @@ class DistilBertTokenizer(PreTrainedTokenizer):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -194,10 +194,12 @@ class ElectraTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -205,10 +205,12 @@ class FunnelTokenizer(PreTrainedTokenizer):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -176,10 +176,12 @@ class LayoutLMTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -168,10 +168,12 @@ class LxmertTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -166,10 +166,12 @@ class MobileBertTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -210,10 +210,12 @@ class RoCBertTokenizer(PreTrainedTokenizer):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
# Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -180,10 +180,12 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
|
||||
def get_vocab(self):
|
||||
return dict(self.vocab, **self.added_tokens_encoder)
|
||||
|
||||
def _tokenize(self, text):
|
||||
def _tokenize(self, text, split_special_tokens=False):
|
||||
split_tokens = []
|
||||
if self.do_basic_tokenize:
|
||||
for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
|
||||
for token in self.basic_tokenizer.tokenize(
|
||||
text, never_split=self.all_special_tokens if not split_special_tokens else None
|
||||
):
|
||||
# If the token is part of the never_split set
|
||||
if token in self.basic_tokenizer.never_split:
|
||||
split_tokens.append(token)
|
||||
|
@ -498,6 +498,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
all_special_tokens_extended = {
|
||||
str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
|
||||
}
|
||||
split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
|
||||
|
||||
text, kwargs = self.prepare_for_tokenization(text, **kwargs)
|
||||
|
||||
@ -513,8 +514,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
|
||||
pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
|
||||
text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
|
||||
|
||||
no_split_token = set(self.unique_no_split_tokens)
|
||||
tokens = self.tokens_trie.split(text)
|
||||
# split_special_tokens: empty `no_split_token`
|
||||
if split_special_tokens:
|
||||
no_split_token = []
|
||||
tokens = [text]
|
||||
else:
|
||||
no_split_token = set(self.unique_no_split_tokens)
|
||||
tokens = self.tokens_trie.split(text)
|
||||
|
||||
# ["This is something", "<special_token_1>", " else"]
|
||||
for i, token in enumerate(tokens):
|
||||
if token in no_split_token:
|
||||
|
@ -1492,6 +1492,11 @@ INIT_TOKENIZER_DOCSTRING = r"""
|
||||
clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should cleanup the spaces that were added when splitting the input text during the
|
||||
tokenization process.
|
||||
split_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the special tokens should be split during the tokenization process. The default behavior is
|
||||
to not split special tokens. This means that if `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") =
|
||||
['<s>`]. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will be give `['<',
|
||||
's', '>']`. This argument is only supported for `slow` tokenizers for the moment.
|
||||
"""
|
||||
|
||||
|
||||
@ -1546,6 +1551,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
|
||||
# By default, cleaning tokenization spaces for both fast and slow tokenizers
|
||||
self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
|
||||
|
||||
# By default, do not split special tokens for both fast and slow tokenizers
|
||||
self.split_special_tokens = kwargs.pop("split_special_tokens", False)
|
||||
|
||||
self.deprecation_warnings = (
|
||||
{}
|
||||
) # Use to store when we have already noticed a deprecation warning (avoid overlogging).
|
||||
|
@ -384,6 +384,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_right_and_left_truncation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not implemented")
|
||||
def test_split_special_tokens(self):
|
||||
pass
|
||||
|
||||
def test_encode_plus_with_padding(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
|
@ -264,6 +264,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
def test_right_and_left_truncation(self):
|
||||
pass
|
||||
|
||||
@unittest.skip("Not implemented")
|
||||
def test_split_special_tokens(self):
|
||||
pass
|
||||
|
||||
def test_encode_plus_with_padding(self):
|
||||
tokenizers = self.get_tokenizers(do_lower_case=False)
|
||||
for tokenizer in tokenizers:
|
||||
|
@ -144,6 +144,19 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
|
||||
self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
|
||||
|
||||
def test_split_special_tokens(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
|
||||
_, _, boxes = self.get_question_words_and_boxes()
|
||||
special_token = "[SPECIAL_TOKEN]"
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
|
||||
encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
|
||||
self.assertEqual(len(encoded_special_token), 1)
|
||||
|
||||
encoded_split_special_token = tokenizer.tokenize(
|
||||
special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes
|
||||
)
|
||||
self.assertTrue(len(encoded_split_special_token) > 1)
|
||||
|
||||
@slow
|
||||
def test_sequence_builders(self):
|
||||
tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
|
||||
|
@ -1344,6 +1344,19 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
|
||||
self.assertTrue(special_token_id in p_output)
|
||||
self.assertTrue(special_token_id in cr_output)
|
||||
|
||||
def test_split_special_tokens(self):
|
||||
# TODO this is only possible for slow currently
|
||||
tokenizer = self.get_tokenizer()
|
||||
special_token = "[SPECIAL_TOKEN]"
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
|
||||
encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
|
||||
self.assertEqual(len(encoded_special_token), 1)
|
||||
|
||||
encoded_split_special_token = tokenizer.tokenize(
|
||||
special_token, add_special_tokens=False, split_special_tokens=True
|
||||
)
|
||||
self.assertTrue(len(encoded_split_special_token) > 1)
|
||||
|
||||
def test_training_new_tokenizer(self):
|
||||
# This feature only exists for fast tokenizers
|
||||
if not self.test_rust_tokenizer:
|
||||
|
@ -3909,6 +3909,7 @@ class TokenizerTesterMixin:
|
||||
# Should not raise an error
|
||||
self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
|
||||
|
||||
# TODO This is ran for all models but only tests bert...
|
||||
def test_clean_up_tokenization_spaces(self):
|
||||
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
|
||||
assert tokenizer.clean_up_tokenization_spaces is True
|
||||
@ -3953,3 +3954,29 @@ class TokenizerTesterMixin:
|
||||
tokenizer.clean_up_tokenization_spaces = True
|
||||
decoded = tokenizer.decode(tokens)
|
||||
assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
|
||||
|
||||
def test_split_special_tokens(self):
|
||||
if not self.test_slow_tokenizer:
|
||||
return
|
||||
|
||||
for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
|
||||
special_token = "[SPECIAL_TOKEN]"
|
||||
with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
|
||||
tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
|
||||
|
||||
if not tokenizer.is_fast:
|
||||
# bloom, gptneox etc only have a fast
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
|
||||
encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
|
||||
self.assertEqual(len(encoded_special_token), 1)
|
||||
|
||||
encoded_split_special_token = tokenizer.encode(
|
||||
special_token, add_special_tokens=False, split_special_tokens=True
|
||||
)
|
||||
if len(encoded_split_special_token) == 1:
|
||||
# if we have subword tokenization or special vocab
|
||||
self.assertTrue(
|
||||
encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
|
||||
)
|
||||
else:
|
||||
self.assertTrue(len(encoded_split_special_token) > 1)
|
||||
|
Loading…
Reference in New Issue
Block a user