Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Change the way sentinel tokens can be retrieved (#20373)

* Change the way sentinel tokens can be retrieved
* Fix line length for doc string
* Fix line length for doc string
* Add a stronger test for t5 tokenization
* Format file changes
* Make a stronger test for filtering sentinel tokens
* Fix file format issues
This commit is contained in:
parent 81d82e4f78
commit 03ae1f060b
@@ -79,12 +79,11 @@ class T5Tokenizer(PreTrainedTokenizer):
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        extra_ids (`int`, *optional*, defaults to 100):
            Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
            indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
            like in T5 preprocessing see
            [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
        additional_special_tokens (`List[str]`, *optional*):
            Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are
            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be
            retrieved by calling the get_sentinel_tokens method, and token ids can be retrieved by calling the
            get_sentinel_token_ids method.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
        sp_model_kwargs (`dict`, *optional*):
            Will be passed to the `SentencePieceProcessor.__init__()` method. The [Python wrapper for
@@ -213,6 +212,14 @@ class T5Tokenizer(PreTrainedTokenizer):
            return ([0] * len(token_ids_0)) + [1]
        return ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]

    def get_sentinel_tokens(self):
        return list(
            set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
        )

    def get_sentinel_token_ids(self):
        return [self._convert_token_to_id(token) for token in self.get_sentinel_tokens()]

    def _add_eos_if_not_present(self, token_ids: List[int]) -> List[int]:
        """Do not add eos again if user already added it."""
        if len(token_ids) > 0 and token_ids[-1] == self.eos_token_id:
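A minimal usage sketch for the new helpers (the "t5-small" checkpoint is an assumption for illustration and is not part of this change):

    from transformers import T5Tokenizer

    tokenizer = T5Tokenizer.from_pretrained("t5-small")  # uses the default extra_ids=100
    sentinel_tokens = tokenizer.get_sentinel_tokens()    # ["<extra_id_0>", ..., "<extra_id_99>"], order not guaranteed (built from a set)
    sentinel_ids = tokenizer.get_sentinel_token_ids()    # the matching ids, e.g. 32000-32099 for t5-small
    assert len(sentinel_tokens) == len(sentinel_ids) == 100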
@@ -16,6 +16,7 @@

import os
import re
import warnings
from shutil import copyfile
from typing import List, Optional, Tuple
@@ -90,11 +91,9 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
        pad_token (`str`, *optional*, defaults to `"<pad>"`):
            The token used for padding, for example when batching sequences of different lengths.
        extra_ids (`int`, *optional*, defaults to 100):
            Add a number of extra ids added to the end of the vocabulary for use as sentinels. These tokens are
            accessible as "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. Extra tokens are
            indexed from the end of the vocabulary up to beginning ("<extra_id_0>" is the last token in the vocabulary
            like in T5 preprocessing see
            [here](https://github.com/google-research/text-to-text-transfer-transformer/blob/9fd7b14a769417be33bc6c850f9598764913c833/t5/data/preprocessors.py#L2117)).
            Add a number of extra ids added to the vocabulary for use as sentinels. These tokens are accessible as
            "<extra_id_{%d}>" where "{%d}" is a number between 0 and extra_ids-1. These tokens can be retrieved by
            calling the get_sentinel_tokens method, and token ids can be retrieved by calling the get_sentinel_token_ids method.
        additional_special_tokens (`List[str]`, *optional*):
            Additional special tokens used by the tokenizer.
    """
@@ -235,3 +234,11 @@ class T5TokenizerFast(PreTrainedTokenizerFast):
        if token_ids_1 is None:
            return len(token_ids_0 + eos) * [0]
        return len(token_ids_0 + eos + token_ids_1 + eos) * [0]

    def get_sentinel_tokens(self):
        return list(
            set(filter(lambda x: re.search(r"<extra_id_\d+>", x) is not None, self.additional_special_tokens))
        )

    def get_sentinel_token_ids(self):
        return [self.convert_tokens_to_ids(token) for token in self.get_sentinel_tokens()]
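A standalone illustration of the filtering these helpers perform ("<custom_sep>" is a hypothetical extra special token used only for this example):

    import re

    additional_special_tokens = ["<extra_id_0>", "<extra_id_1>", "<custom_sep>"]
    sentinels = [t for t in additional_special_tokens if re.search(r"<extra_id_\d+>", t) is not None]
    print(sentinels)  # ['<extra_id_0>', '<extra_id_1>']; "<custom_sep>" is filtered out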
@@ -14,6 +14,7 @@
# limitations under the License.
import json
import os
import re
import tempfile
import unittest

@@ -379,3 +380,25 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
            model_name="t5-base",
            revision="5a7ff2d8f5117c194c7e32ec1ccbf04642cca99b",
        )

    def test_get_sentinel_tokens(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
        sentinel_tokens = tokenizer.get_sentinel_tokens()
        self.assertEqual(len(sentinel_tokens), 10)
        self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))

    def test_get_sentinel_token_ids(self):
        tokenizer = T5Tokenizer(SAMPLE_VOCAB, extra_ids=10)
        self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))

    def test_get_sentinel_tokens_for_fasttokenizer(self):
        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
        sentinel_tokens = tokenizer.get_sentinel_tokens()
        self.assertEqual(len(sentinel_tokens), 10)
        self.assertListEqual(sorted(sentinel_tokens), sorted([f"<extra_id_{str(i)}>" for i in range(0, 10)]))
        self.assertTrue(all(re.search(r"<extra_id_\d+>", token) is not None for token in sentinel_tokens))

    def test_get_sentinel_token_ids_for_fasttokenizer(self):
        tokenizer = T5TokenizerFast(SAMPLE_VOCAB, extra_ids=10)
        self.assertListEqual(sorted(tokenizer.get_sentinel_token_ids()), sorted(range(1000, 1010)))
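The expected ids in these tests follow from the vocabulary layout assumed for the sample model (the 1,000-piece base size is inferred from the expected ids, not stated in this change):

    # The 10 extra ids are appended after the base vocabulary, so the sentinels
    # occupy the top ids 1000-1009.
    base_vocab_size = 1000  # assumed size of SAMPLE_VOCAB
    extra_ids = 10
    expected_sentinel_ids = list(range(base_vocab_size, base_vocab_size + extra_ids))
    assert expected_sentinel_ids == list(range(1000, 1010))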