[split_special_tokens] Add support for split_special_tokens argument to encode (#25081)

* draft changes

* update and add tests

* styling for now

* move test

* path to usable model

* update test

* small update

* update bert-based tokenizers

* don't use kwargs for _tokenize

* don't use kwargs for _tokenize

* fix copies

* update

* update test for special tokenizers

* fixup

* skip two tests

* remove pdb breakpoint()

* wowo

* rewrite custom tests

* nits

* revert change in target keys

* fix markup lm

* update documentation of the argument
Arthur 2023-08-18 13:26:27 +02:00 committed by GitHub
parent 9d7afd2536
commit 30b3c46ff5
18 changed files with 122 additions and 24 deletions
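
A minimal usage sketch of the new argument (assuming a slow BERT tokenizer such as bert-base-uncased; the exact sub-pieces depend on the vocabulary):

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Default: the special token is protected and maps to a single id
tokenizer.encode("[CLS]", add_special_tokens=False)                             # one id
# With the new flag, it is treated as plain text and split into sub-pieces
tokenizer.encode("[CLS]", add_special_tokens=False, split_special_tokens=True)  # several ids, e.g. for "[", "cls", "]"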

View File

@@ -238,10 +238,12 @@ class BertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)
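
The change above drops the `never_split` protection when `split_special_tokens=True`. A small illustration of what `never_split` does in the underlying `BasicTokenizer` (illustrative only; `[UNK]` stands in for any special token):

from transformers.models.bert.tokenization_bert import BasicTokenizer

basic = BasicTokenizer(do_lower_case=True)

# With never_split, the special token survives lower-casing and punctuation splitting
basic.tokenize("hello [UNK] world", never_split=["[UNK]"])  # ['hello', '[UNK]', 'world']

# With never_split=None (what split_special_tokens=True now passes), it is ordinary text
basic.tokenize("hello [UNK] world")                         # ['hello', '[', 'unk', ']', 'world']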

View File

@@ -177,10 +177,12 @@ class ConvBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -178,10 +178,12 @@ class RetriBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -195,10 +195,12 @@ class DistilBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -194,10 +194,12 @@ class ElectraTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -205,10 +205,12 @@ class FunnelTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -176,10 +176,12 @@ class LayoutLMTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -168,10 +168,12 @@ class LxmertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -166,10 +166,12 @@ class MobileBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -210,10 +210,12 @@ class RoCBertTokenizer(PreTrainedTokenizer):
         return dict(self.vocab, **self.added_tokens_encoder)
     # Copied from transformers.models.bert.tokenization_bert.BertTokenizer._tokenize
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -180,10 +180,12 @@ class SqueezeBertTokenizer(PreTrainedTokenizer):
     def get_vocab(self):
         return dict(self.vocab, **self.added_tokens_encoder)
-    def _tokenize(self, text):
+    def _tokenize(self, text, split_special_tokens=False):
         split_tokens = []
         if self.do_basic_tokenize:
-            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
+            for token in self.basic_tokenizer.tokenize(
+                text, never_split=self.all_special_tokens if not split_special_tokens else None
+            ):
                 # If the token is part of the never_split set
                 if token in self.basic_tokenizer.never_split:
                     split_tokens.append(token)

View File

@@ -498,6 +498,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
         all_special_tokens_extended = {
             str(t): t for t in self.all_special_tokens_extended if isinstance(t, AddedToken)
         }
+        split_special_tokens = kwargs.pop("split_special_tokens", self.split_special_tokens)
         text, kwargs = self.prepare_for_tokenization(text, **kwargs)
@@ -513,8 +514,14 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):
             pattern = r"(" + r"|".join(escaped_special_toks) + r")|" + r"(.+?)"
             text = re.sub(pattern, lambda m: m.groups()[0] or m.groups()[1].lower(), text)
-        no_split_token = set(self.unique_no_split_tokens)
-        tokens = self.tokens_trie.split(text)
+        # split_special_tokens: empty `no_split_token`
+        if split_special_tokens:
+            no_split_token = []
+            tokens = [text]
+        else:
+            no_split_token = set(self.unique_no_split_tokens)
+            tokens = self.tokens_trie.split(text)
         # ["This is something", "<special_token_1>", " else"]
         for i, token in enumerate(tokens):
             if token in no_split_token:
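
A sketch of what the two branches above produce, using the same `Trie` that slow tokenizers rely on (mirrors the example in the inline comment):

from transformers.tokenization_utils import Trie

trie = Trie()
trie.add("<special_token_1>")
text = "This is something<special_token_1> else"

# Default branch: the trie isolates the special token so it is never sub-tokenized
trie.split(text)  # ['This is something', '<special_token_1>', ' else']

# split_special_tokens branch: the raw text is handed to _tokenize in one piece
[text]            # ['This is something<special_token_1> else']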

View File

@@ -1492,6 +1492,11 @@ INIT_TOKENIZER_DOCSTRING = r"""
         clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
             Whether or not the model should cleanup the spaces that were added when splitting the input text during the
             tokenization process.
+        split_special_tokens (`bool`, *optional*, defaults to `False`):
+            Whether or not the special tokens should be split during the tokenization process. The default behavior is
+            to not split special tokens. This means that if `<s>` is the `bos_token`, then `tokenizer.tokenize("<s>") =
+            ['<s>']`. Otherwise, if `split_special_tokens=True`, then `tokenizer.tokenize("<s>")` will give `['<',
+            's', '>']`. This argument is only supported for `slow` tokenizers for the moment.
 """
@@ -1546,6 +1551,9 @@ class PreTrainedTokenizerBase(SpecialTokensMixin, PushToHubMixin):
         # By default, cleaning tokenization spaces for both fast and slow tokenizers
         self.clean_up_tokenization_spaces = kwargs.pop("clean_up_tokenization_spaces", True)
+        # By default, do not split special tokens for both fast and slow tokenizers
+        self.split_special_tokens = kwargs.pop("split_special_tokens", False)
         self.deprecation_warnings = (
             {}
         )  # Use to store when we have already noticed a deprecation warning (avoid overlogging).
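
The documented behavior in practice (a sketch; `MySlowTokenizer` and its checkpoint are hypothetical stand-ins for any slow tokenizer whose `bos_token` is `<s>`):

tok = MySlowTokenizer.from_pretrained("some-checkpoint")  # hypothetical slow tokenizer, bos_token="<s>"

tok.tokenize("<s>")                             # ['<s>']  (default: special tokens kept whole)
tok.tokenize("<s>", split_special_tokens=True)  # e.g. ['<', 's', '>'] (exact pieces depend on the vocab)

# Because __init__ now pops the kwarg, the default can also be set at construction time:
tok = MySlowTokenizer.from_pretrained("some-checkpoint", split_special_tokens=True)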

View File

@@ -384,6 +384,10 @@ class LayoutLMv2TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_right_and_left_truncation(self):
         pass
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:

View File

@@ -264,6 +264,10 @@ class LayoutLMv3TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def test_right_and_left_truncation(self):
         pass
+    @unittest.skip("Not implemented")
+    def test_split_special_tokens(self):
+        pass
     def test_encode_plus_with_padding(self):
         tokenizers = self.get_tokenizers(do_lower_case=False)
         for tokenizer in tokenizers:

View File

@@ -144,6 +144,19 @@ class LayoutXLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_2)
         self.assertEqual(encoding_tokenizer_slow_1, encoding_tokenizer_slow_3)
+    def test_split_special_tokens(self):
+        tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")
+        _, _, boxes = self.get_question_words_and_boxes()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, boxes=boxes, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True, boxes=boxes
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
     @slow
     def test_sequence_builders(self):
         tokenizer = self.tokenizer_class.from_pretrained("microsoft/layoutxlm-base")

View File

@@ -1344,6 +1344,19 @@ class MarkupLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         self.assertTrue(special_token_id in p_output)
         self.assertTrue(special_token_id in cr_output)
+    def test_split_special_tokens(self):
+        # TODO this is only possible for slow currently
+        tokenizer = self.get_tokenizer()
+        special_token = "[SPECIAL_TOKEN]"
+        tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+        encoded_special_token = tokenizer.tokenize(special_token, add_special_tokens=False)
+        self.assertEqual(len(encoded_special_token), 1)
+        encoded_split_special_token = tokenizer.tokenize(
+            special_token, add_special_tokens=False, split_special_tokens=True
+        )
+        self.assertTrue(len(encoded_split_special_token) > 1)
     def test_training_new_tokenizer(self):
         # This feature only exists for fast tokenizers
         if not self.test_rust_tokenizer:

View File

@@ -3909,6 +3909,7 @@ class TokenizerTesterMixin:
             # Should not raise an error
             self.rust_tokenizer_class.from_pretrained(tmp_dir_2)
+    # TODO This is ran for all models but only tests bert...
     def test_clean_up_tokenization_spaces(self):
         tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
         assert tokenizer.clean_up_tokenization_spaces is True
@@ -3953,3 +3954,29 @@ class TokenizerTesterMixin:
         tokenizer.clean_up_tokenization_spaces = True
         decoded = tokenizer.decode(tokens)
         assert decoded == "[CLS] this shouldn't be! he'll go. [SEP]"
+    def test_split_special_tokens(self):
+        if not self.test_slow_tokenizer:
+            return
+        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
+            special_token = "[SPECIAL_TOKEN]"
+            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
+                tokenizer = self.tokenizer_class.from_pretrained(pretrained_name, **kwargs)
+                if not tokenizer.is_fast:
+                    # bloom, gptneox etc only have a fast
+                    tokenizer.add_special_tokens({"additional_special_tokens": [special_token]})
+                    encoded_special_token = tokenizer.encode(special_token, add_special_tokens=False)
+                    self.assertEqual(len(encoded_special_token), 1)
+                    encoded_split_special_token = tokenizer.encode(
+                        special_token, add_special_tokens=False, split_special_tokens=True
+                    )
+                    if len(encoded_split_special_token) == 1:
+                        # if we have subword tokenization or special vocab
+                        self.assertTrue(
+                            encoded_split_special_token[0] != tokenizer.convert_tokens_to_ids(special_token)
+                        )
+                    else:
+                        self.assertTrue(len(encoded_split_special_token) > 1)