Mirror of https://github.com/huggingface/transformers.git
[pegasus] Faster tokenizer tests (#7672)
This commit is contained in:
parent
bc00b37a0d
commit
b0f05e0c4c
20  scripts/pegasus/build_test_sample_spm_no_bos.py  (new executable file)
@@ -0,0 +1,20 @@
+#!/usr/bin/env python
+
+# this script builds a small sample spm file tests/fixtures/test_sentencepiece_no_bos.model, with features needed by pegasus
+
+# 1. pip install sentencepiece
+#
+# 2. wget https://raw.githubusercontent.com/google/sentencepiece/master/data/botchan.txt
+
+# 3. build
+import sentencepiece as spm
+
+# pegasus:
+# 1. no bos
+# 2. eos_id is 1
+# 3. unk_id is 2
+# build a sample spm file accordingly
+spm.SentencePieceTrainer.train('--input=botchan.txt --model_prefix=test_sentencepiece_no_bos --bos_id=-1 --unk_id=2 --eos_id=1 --vocab_size=1000')
+
+# 4. now update the fixture
+# mv test_sentencepiece_no_bos.model ../../tests/fixtures/
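Not part of the commit, but a quick way to confirm that the trained sample model has the properties the script's comments call out (assuming sentencepiece is installed and the training step above has been run in the current directory):

import sentencepiece as spm

# Sanity-check the freshly built model before moving it into tests/fixtures/:
# pegasus expects no BOS piece, EOS at id 1, and UNK at id 2.
sp = spm.SentencePieceProcessor()
sp.Load("test_sentencepiece_no_bos.model")

assert sp.get_piece_size() == 1000  # --vocab_size=1000
assert sp.bos_id() == -1            # --bos_id=-1 disables BOS
assert sp.eos_id() == 1             # --eos_id=1
assert sp.unk_id() == 2             # --unk_id=2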
@@ -184,13 +184,23 @@ def require_faiss(test_case):
     return test_case


-def get_tests_dir():
+def get_tests_dir(append_path=None):
     """
-    returns the full path to the `tests` dir, so that the tests can be invoked from anywhere
+    Args:
+        append_path: optional path to append to the tests dir path
+
+    Return:
+        The full path to the `tests` dir, so that the tests can be invoked from anywhere.
+        Optionally `append_path` is joined after the `tests` dir if the former is provided.
+
     """
     # this function caller's __file__
     caller__file__ = inspect.stack()[1][1]
-    return os.path.abspath(os.path.dirname(caller__file__))
+    tests_dir = os.path.abspath(os.path.dirname(caller__file__))
+    if append_path:
+        return os.path.join(tests_dir, append_path)
+    else:
+        return tests_dir


 #
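For context, the new `append_path` argument lets test modules locate fixtures without any `os.path` gymnastics; a minimal sketch of the intended call pattern (the fixture path is the one added by this commit):

from transformers.testing_utils import get_tests_dir

# Resolves relative to the calling module's directory (the tests/ dir when
# called from a test file), so tests can be launched from any working directory.
tests_dir = get_tests_dir()
sample_vocab = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
print(tests_dir, sample_vocab)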
@@ -49,7 +49,7 @@ class PegasusTokenizer(ReformerTokenizer):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # Dont use reserved words added_token_encoder, added_tokens_decoder because of
+        # Don't use reserved words added_token_encoder, added_tokens_decoder because of
         # AssertionError: Non-consecutive added token '1' found. in from_pretrained
         assert len(self.added_tokens_decoder) == 0
         self.encoder: Dict[int, str] = {0: self.pad_token, 1: self.eos_token}
@@ -58,7 +58,7 @@ class PegasusTokenizer(ReformerTokenizer):
         self.decoder: Dict[str, int] = {v: k for k, v in self.encoder.items()}

     def _convert_token_to_id(self, token: str) -> int:
-        """ Converts a token (str) in an id using the vocab. """
+        """ Converts a token (str) to an id using the vocab. """
         if token in self.decoder:
             return self.decoder[token]
         elif token in self.added_tokens_decoder:
@@ -67,7 +67,7 @@ class PegasusTokenizer(ReformerTokenizer):
         return sp_id + self.offset

     def _convert_id_to_token(self, index: int) -> str:
-        """Converts an index (integer) in a token (str) using the vocab."""
+        """Converts an index (integer) to a token (str) using the vocab."""
         if index in self.encoder:
             return self.encoder[index]
         elif index in self.added_tokens_encoder:
@@ -81,11 +81,6 @@ class PegasusTokenizer(ReformerTokenizer):
     def vocab_size(self) -> int:
         return len(self.sp_model) + self.offset

-    def get_vocab(self) -> Dict[str, int]:
-        vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
-        vocab.update(self.added_tokens_encoder)
-        return vocab
-
     def num_special_tokens_to_add(self, pair=False):
         """Just EOS"""
         return 1
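Taken together, the hunks above describe pegasus' id layout: a handful of reserved ids come first (pad at 0, eos at 1), raw sentencepiece ids are shifted up by `self.offset`, and `vocab_size` is the sentencepiece size plus that offset. A standalone sketch of the same convention, with an illustrative `OFFSET` value and token strings rather than the tokenizer's real attributes:

from typing import Callable, Dict

OFFSET = 103  # illustrative assumption; the real value lives on PegasusTokenizer.offset

encoder: Dict[int, str] = {0: "<pad>", 1: "</s>"}
decoder: Dict[str, int] = {v: k for k, v in encoder.items()}

def token_to_id(token: str, sp_piece_to_id: Callable[[str], int]) -> int:
    # Reserved tokens map directly; everything else is a shifted sentencepiece id.
    if token in decoder:
        return decoder[token]
    return sp_piece_to_id(token) + OFFSET

def id_to_token(index: int, sp_id_to_piece: Callable[[int], str]) -> str:
    if index in encoder:
        return encoder[index]
    return sp_id_to_piece(index - OFFSET)

def vocab_size(sp_size: int) -> int:
    return sp_size + OFFSET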
@@ -109,7 +104,7 @@ class PegasusTokenizer(ReformerTokenizer):

     def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None) -> List[int]:
         """
-        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
+        Build model inputs from a sequence or a pair of sequences for sequence classification tasks
         by concatenating and adding special tokens.
         A Pegasus sequence has the following format, where ``X`` represents the sequence:

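The body of the method falls outside this hunk; given that `num_special_tokens_to_add` above returns 1 ("Just EOS") and the fixture-building script sets eos_id to 1, here is a rough sketch of the behavior being documented (the helper name and the pair-handling detail are assumptions, not the library's code):

from typing import List, Optional

EOS_ID = 1  # eos_id is 1, per the fixture-building script above

def build_inputs_with_special_tokens_sketch(
    token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
) -> List[int]:
    # No BOS is prepended; the single special token added is EOS.
    if token_ids_1 is None:
        return token_ids_0 + [EOS_ID]
    # Assumed: pairs are simply concatenated before the trailing EOS.
    return token_ids_0 + token_ids_1 + [EOS_ID]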
@@ -17,6 +17,7 @@

 import os
 from shutil import copyfile
+from typing import Dict

 from .tokenization_utils import PreTrainedTokenizer
 from .tokenization_utils_fast import PreTrainedTokenizerFast
@@ -119,7 +120,7 @@ class ReformerTokenizer(PreTrainedTokenizer):
     def vocab_size(self):
         return self.sp_model.get_piece_size()

-    def get_vocab(self):
+    def get_vocab(self) -> Dict[str, int]:
         vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
         vocab.update(self.added_tokens_encoder)
         return vocab
@@ -186,7 +186,7 @@ class PreTrainedTokenizer(PreTrainedTokenizerBase):

             num_added_toks = tokenizer.add_tokens(['new_tok1', 'my_new-tok2'])
             print('We have added', num_added_toks, 'tokens')
-            # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
+            # Note: resize_token_embeddings expects to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
             model.resize_token_embeddings(len(tokenizer))
         """
         new_tokens = [str(tok) for tok in new_tokens]
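The hunk shows only the middle of the docstring example; for completeness, a hedged end-to-end version of the same workflow (the BERT checkpoint is just an illustrative choice, not something this commit touches):

from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

num_added_toks = tokenizer.add_tokens(["new_tok1", "my_new-tok2"])
print("We have added", num_added_toks, "tokens")

# resize_token_embeddings expects the full size of the new vocabulary,
# i.e. the length of the tokenizer, not just the number of added tokens.
model.resize_token_embeddings(len(tokenizer))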
BIN  tests/fixtures/test_sentencepiece_no_bos.model  (vendored, new file)
Binary file not shown.
@@ -1,13 +1,15 @@
 import unittest
-from pathlib import Path

 from transformers.file_utils import cached_property
-from transformers.testing_utils import require_torch
+from transformers.testing_utils import get_tests_dir, require_torch
 from transformers.tokenization_pegasus import PegasusTokenizer, PegasusTokenizerFast

 from .test_tokenization_common import TokenizerTesterMixin


+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")
+
+
 class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

     tokenizer_class = PegasusTokenizer
@@ -17,10 +19,8 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
     def setUp(self):
         super().setUp()

-        save_dir = Path(self.tmpdirname)
-        spm_file = PegasusTokenizer.vocab_files_names["vocab_file"]
-        if not (save_dir / spm_file).exists():
-            tokenizer = self.pegasus_large_tokenizer
-            tokenizer.save_pretrained(self.tmpdirname)
+        # We have a SentencePiece fixture for testing
+        tokenizer = PegasusTokenizer(SAMPLE_VOCAB)
+        tokenizer.save_pretrained(self.tmpdirname)

     @cached_property
@@ -32,9 +32,6 @@ class PegasusTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
         pass

     def get_tokenizer(self, **kwargs) -> PegasusTokenizer:
-        if not kwargs:
-            return self.pegasus_large_tokenizer
-        else:
-            return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)
+        return PegasusTokenizer.from_pretrained(self.tmpdirname, **kwargs)

     def get_input_output_texts(self, tokenizer):
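The net effect of these test changes is that `setUp` no longer downloads and saves the pegasus-large tokenizer; it instantiates `PegasusTokenizer` directly from the tiny local fixture. A hypothetical sanity check, not part of the commit, illustrating that the fixture alone is enough to build a working tokenizer offline (the assertions rely on the eos/offset layout described in the hunks above):

import unittest

from transformers.testing_utils import get_tests_dir
from transformers.tokenization_pegasus import PegasusTokenizer

SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece_no_bos.model")


class FixtureSanityTest(unittest.TestCase):
    def test_fixture_builds_tokenizer(self):
        tokenizer = PegasusTokenizer(SAMPLE_VOCAB)
        # 1000 sentencepiece pieces plus pegasus' reserved/offset ids.
        self.assertGreater(tokenizer.vocab_size, 1000)
        # id 1 is reserved for EOS (see the encoder built in __init__ above).
        self.assertEqual(tokenizer.convert_tokens_to_ids(tokenizer.eos_token), 1)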
@@ -14,19 +14,18 @@
 # limitations under the License.


-import os
 import unittest

 from transformers import BatchEncoding
 from transformers.file_utils import cached_property
-from transformers.testing_utils import _torch_available
+from transformers.testing_utils import _torch_available, get_tests_dir
 from transformers.tokenization_t5 import T5Tokenizer, T5TokenizerFast
 from transformers.tokenization_xlnet import SPIECE_UNDERLINE

 from .test_tokenization_common import TokenizerTesterMixin


-SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)), "fixtures/test_sentencepiece.model")
+SAMPLE_VOCAB = get_tests_dir("fixtures/test_sentencepiece.model")

 FRAMEWORK = "pt" if _torch_available else "tf"
