Change the way sentinel tokens can be retrieved (#20373)

* Change the way sentinel tokens can be retrieved

* Fix line length for doc string

* Fix line length for doc string

* Add a stronger test for T5 tokenization

* Format file changes

* Make a stronger test for filtering sentinel tokens

* fix file format issues
raghavanone 2022-11-23 20:05:44 +05:30 committed by GitHub
parent 81d82e4f78
commit 03ae1f060b
3 changed files with 48 additions and 11 deletions
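
For reference, a minimal sketch of how the helpers introduced by this commit would be used; the checkpoint name "t5-small" is only an illustrative choice, and the printed counts assume the default extra_ids=100:

from transformers import T5Tokenizer

# Load any T5 checkpoint; "t5-small" is just an example.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# New in this commit: list the "<extra_id_*>" sentinel tokens directly
# (the result is built from a set, so its order is not guaranteed) ...
sentinel_tokens = tokenizer.get_sentinel_tokens()

# ... and their vocabulary ids, without reasoning about vocab-size offsets.
sentinel_ids = tokenizer.get_sentinel_token_ids()

print(len(sentinel_tokens), len(sentinel_ids))  # 100 100 with the default extra_ids=100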

@@ -79,12 +79,11 @@ class T5Tokenizer(PreTrainedTokenizer):
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
extra_ids (`int`, *optional*, defaults to 100):
Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
like in T5 preprocessing see
[here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
additional_special_tokens (`List[str]`, *optional*):
Add a number of extra ids to the vocabulary for use as sentinels. These tokens are accessible as
"<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
calling the get_sentinel_tokens method, and the corresponding token ids can be retrieved by calling the
get_sentinel_token_ids method.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
sp_model_kwargs (`dict`, *optional*):
Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
@@ -213,6 +212,14 @@ class T5Tokenizer(PreTrainedTokenizer):
return ([0] * len(token_ids_0)) + [1]
return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
def get_sentinel_tokens(self):
    return list(
        set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
    )
def get_sentinel_token_ids(self):
    return [self._convert_token_to_id(token) for token in self.get_sentinel_tokens()]
def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
"""Do not add eos again if user already added it."""
if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:

@@ -16,6 +16,7 @@
import os
import re
import warnings
from shutil import copyfile
from typing import List, Optional, Tuple
@@ -90,11 +91,9 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
pad_token (`str`, *optional*, defaults to `"<pad>"`):
The token used for padding, for example when batching sequences of different lengths.
extra_ids (`int`, *optional*, defaults to 100):
Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
like in T5 preprocessing see
[here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
Add a number of extra ids to the vocabulary for use as sentinels. These tokens are accessible as
"<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
calling the get_sentinel_tokens method, and the token ids by calling the get_sentinel_token_ids method.
additional_special_tokens (`List[str]`, *optional*):
Additional special tokens used by the tokenizer.
"""
@@ -235,3 +234,11 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
if token_ids_1 is None:
return len(token_ids_0 + eos) * [0]
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
def get_sentinel_tokens(self):
    return list(
        set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
    )
def get_sentinel_token_ids(self):
    return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]

@@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import re
import tempfile
import unittest
@@ -379,3 +380,25 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
model_name="t5-base",
revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
)
def test_get_sentinel_tokens(self):
    tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
    sentinel_tokens = tokenizer.get_sentinel_tokens()
    self.assertEqual(len(sentinel_tokens), 10)
    self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{i}>" for i in range(10)]))
    self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
def test_get_sentinel_token_ids(self):
    tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
    self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), list(range(1000, 1010)))
def test_get_sentinel_tokens_for_fasttokenizer(self):
    tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
    sentinel_tokens = tokenizer.get_sentinel_tokens()
    self.assertEqual(len(sentinel_tokens), 10)
    self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{i}>" for i in range(10)]))
    self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))
def test_get_sentinel_token_ids_for_fasttokenizer(self):
    tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
    self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), list(range(1000, 1010)))
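
A side note on the expected ids above: extra ids are appended after the base vocabulary, so with a base vocabulary of 1,000 pieces the ten sentinels occupy ids 1000-1009. The fixture size is an assumption inferred from the test; a minimal sketch of the arithmetic:

# Hypothetical illustration: sentinel ids sit right after the base vocabulary.
base_vocab_size = 1000  # assumed size of the SAMPLE_VOCAB fixture
extra_ids = 10
expected_sentinel_ids = list(range(base_vocab_size, base_vocab_size + extra_ids))
assert expected_sentinel_ids == list(range(1000, 1010))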