[cleanup] test_tokenization_common.py (#4390)

Sam Shleifer 2020-05-19 10:46:55 -04:00 committed by GitHub
parent 8f1d047148
commit 07dd7c2fd8
13 changed files with 62 additions and 98 deletions

View File

@@ -198,11 +198,12 @@ Follow these steps to start contributing:
are useful to avoid duplicated work, and to differentiate it from PRs ready
to be merged;
4. Make sure existing tests pass;
5. Add high-coverage tests. No quality test, no merge.
5. Add high-coverage tests. No quality testing = no merge.
- If you are adding a new model, make sure that you use `ModelTester.all_model_classes = (MyModel, MyModelWithLMHead,...)`, which triggers the common tests.
- If you are adding new `@slow` tests, make sure they pass using `RUN_SLOW=1 python -m pytest tests/test_my_new_model.py`.
- If you are adding a new tokenizer, write tests, and make sure `RUN_SLOW=1 python -m pytest tests/test_tokenization_{your_model_name}.py` passes.
CircleCI does not run them.
6. All public methods must have informative docstrings;
6. All public methods must have informative docstrings that work nicely with sphinx. See `modeling_ctrl.py` for an example.
### Tests
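
The tokenizer-testing guideline in the hunk above is exactly what this commit streamlines: with `get_tokenizer()` now implemented on `TokenizerTesterMixin`, a new tokenizer test only has to set `tokenizer_class` and write a fixture vocabulary. A minimal sketch (the file name, class name, and the BERT stand-in tokenizer are all hypothetical):

```python
# tests/test_tokenization_my_model.py  (hypothetical file)
import os
import unittest

from transformers import BertTokenizer  # stand-in for the new tokenizer class
from transformers.tokenization_bert import VOCAB_FILES_NAMES

from .test_tokenization_common import TokenizerTesterMixin


class MyModelTokenizationTest(TokenizerTesterMixin, unittest.TestCase):

    tokenizer_class = BertTokenizer  # the mixin's get_tokenizer() loads this class from tmpdirname

    def setUp(self):
        super().setUp()
        # Tiny fixture vocabulary for the common tests to load.
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]", "this", "is", "a", "test"]
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
```

Running `RUN_SLOW=1 python -m pytest tests/test_tokenization_my_model.py` then exercises everything, including any `@slow` tests that CircleCI skips.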

View File

@@ -199,7 +199,7 @@ class RobertaTokenizer(GPT2Tokenizer):
if token_ids_1 is not None:
raise ValueError(
"You should not supply a second sequence if the provided sequence of "
"ids is already formated with special tokens for the model."
"ids is already formatted with special tokens for the model."
)
return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
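
For context, the method touched here is `RobertaTokenizer.get_special_tokens_mask`; with `already_has_special_tokens=True` it flags the positions of `<s>`/`</s>` in an already formatted sequence. A quick sketch (assumes the public `roberta-base` checkpoint is available):

```python
from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

ids = tokenizer.encode("Encode this.")  # encode() wraps the ids in <s> ... </s>
mask = tokenizer.get_special_tokens_mask(ids, already_has_special_tokens=True)

# 1 marks <s>/</s> positions, 0 marks ordinary tokens, so mask looks like [1, 0, ..., 0, 1].
# Passing a second sequence in this mode raises the ValueError shown above.
print(list(zip(tokenizer.convert_ids_to_tokens(ids), mask)))
```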

View File

@@ -771,26 +771,26 @@ class PreTrainedTokenizer(SpecialTokensMixin):
raise NotImplementedError
@property
def is_fast(self):
def is_fast(self) -> bool:
return False
@property
def max_len(self):
def max_len(self) -> int:
""" Kept here for backward compatibility.
Now renamed to `model_max_length` to avoid ambiguity.
"""
return self.model_max_length
@property
def max_len_single_sentence(self):
def max_len_single_sentence(self) -> int:
return self.model_max_length - self.num_special_tokens_to_add(pair=False)
@property
def max_len_sentences_pair(self):
def max_len_sentences_pair(self) -> int:
return self.model_max_length - self.num_special_tokens_to_add(pair=True)
@max_len_single_sentence.setter
def max_len_single_sentence(self, value):
def max_len_single_sentence(self, value) -> int:
""" For backward compatibility, allow to try to setup 'max_len_single_sentence' """
if value == self.model_max_length - self.num_special_tokens_to_add(pair=False):
logger.warning(
@@ -802,7 +802,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
)
@max_len_sentences_pair.setter
def max_len_sentences_pair(self, value):
def max_len_sentences_pair(self, value) -> int:
""" For backward compatibility, allow to try to setup 'max_len_sentences_pair' """
if value == self.model_max_length - self.num_special_tokens_to_add(pair=True):
logger.warning(
@@ -1118,7 +1118,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
return vocab_files + (special_tokens_map_file, added_tokens_file)
def save_vocabulary(self, save_directory):
def save_vocabulary(self, save_directory) -> Tuple[str]:
""" Save the tokenizer vocabulary to a directory. This method does *NOT* save added tokens
and special token mappings.
@@ -1128,7 +1128,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
"""
raise NotImplementedError
def add_tokens(self, new_tokens):
def add_tokens(self, new_tokens: Union[str, List[str]]) -> int:
"""
Add a list of new tokens to the tokenizer class. If the new tokens are not in the
vocabulary, they are added to it with indices starting from length of the current vocabulary.
@@ -1156,7 +1156,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if not isinstance(new_tokens, list):
new_tokens = [new_tokens]
to_add_tokens = []
tokens_to_add = []
for token in new_tokens:
assert isinstance(token, str)
if self.init_kwargs.get("do_lower_case", False) and token not in self.all_special_tokens:
@@ -1164,18 +1164,18 @@ class PreTrainedTokenizer(SpecialTokensMixin):
if (
token != self.unk_token
and self.convert_tokens_to_ids(token) == self.convert_tokens_to_ids(self.unk_token)
and token not in to_add_tokens
and token not in tokens_to_add
):
to_add_tokens.append(token)
tokens_to_add.append(token)
logger.info("Adding %s to the vocabulary", token)
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(to_add_tokens))
added_tok_encoder = dict((tok, len(self) + i) for i, tok in enumerate(tokens_to_add))
added_tok_decoder = {v: k for k, v in added_tok_encoder.items()}
self.added_tokens_encoder.update(added_tok_encoder)
self.unique_added_tokens_encoder = set(self.added_tokens_encoder.keys()).union(set(self.all_special_tokens))
self.added_tokens_decoder.update(added_tok_decoder)
return len(to_add_tokens)
return len(tokens_to_add)
def num_special_tokens_to_add(self, pair=False):
"""
@@ -2080,10 +2080,7 @@ class PreTrainedTokenizer(SpecialTokensMixin):
def build_inputs_with_special_tokens(self, token_ids_0: List, token_ids_1: Optional[List] = None) -> List:
"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens.
A RoBERTa sequence has the following format:
single sequence: <s> X </s>
pair of sequences: <s> A </s></s> B </s>
by concatenating and adding special tokens. This implementation does not add special tokens.
"""
if token_ids_1 is None:
return token_ids_0
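
Beyond the variable rename (`to_add_tokens` to `tokens_to_add`), nothing about `add_tokens` changes; for reference, the public behaviour its docstring describes, sketched with `bert-base-uncased` purely as an illustration:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# Only genuinely new tokens count: the duplicate entry and anything already in
# the vocabulary are skipped, and the return value is the number actually added.
num_added = tokenizer.add_tokens(["new_tok_1", "new_tok_2", "new_tok_1"])
assert num_added == 2

# New ids are appended after the existing vocabulary.
print(tokenizer.convert_tokens_to_ids("new_tok_1"))
```

When a model is involved, `model.resize_token_embeddings(len(tokenizer))` keeps the embedding matrix in sync with the enlarged vocabulary.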

View File

@@ -36,9 +36,6 @@ class AlbertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = AlbertTokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return AlbertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "this is a test"
output_text = "this is a test"

View File

@@ -59,9 +59,6 @@ class BertTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs):
return BertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
return BertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

View File

@@ -60,9 +60,6 @@ class BertJapaneseTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
def get_tokenizer(self, **kwargs):
return BertJapaneseTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "こんにちは、世界。 \nこんばんは、世界。"
output_text = "こんにちは 、 世界 。 こんばんは 、 世界 。"

View File

@@ -22,12 +22,12 @@ from collections import OrderedDict
from typing import TYPE_CHECKING, Dict, Tuple, Union
from tests.utils import require_tf, require_torch
from transformers import PreTrainedTokenizer
if TYPE_CHECKING:
from transformers import (
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerFast,
PreTrainedModel,
TFPreTrainedModel,
@@ -67,19 +67,24 @@ class TokenizerTesterMixin:
def tearDown(self):
shutil.rmtree(self.tmpdirname)
def get_tokenizer(self, **kwargs):
raise NotImplementedError
def get_tokenizer(self, **kwargs) -> PreTrainedTokenizer:
return self.tokenizer_class.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
raise NotImplementedError
def get_input_output_texts(self):
raise NotImplementedError
def get_input_output_texts(self) -> Tuple[str, str]:
"""Feel free to overwrite"""
# TODO: @property
return (
"This is a test",
"This is a test",
)
@staticmethod
def convert_batch_encode_plus_format_to_encode_plus(batch_encode_plus_sequences):
# Switch from batch_encode_plus format: {'input_ids': [[...], [...]], ...}
# to the concatenated encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
# to the list of examples/ encode_plus format: [{'input_ids': [...], ...}, {'input_ids': [...], ...}]
return [
{value: batch_encode_plus_sequences[value][i] for value in batch_encode_plus_sequences.keys()}
for i in range(len(batch_encode_plus_sequences["input_ids"]))
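
The renamed comment above is easier to follow with concrete (made-up) values: the helper turns the column-oriented `batch_encode_plus` output into one `encode_plus`-style dict per example.

```python
# Column-oriented batch_encode_plus output (ids are made up for illustration).
batch = {
    "input_ids": [[101, 7592, 102], [101, 2088, 102]],
    "attention_mask": [[1, 1, 1], [1, 1, 1]],
}

# One dict per example, i.e. the layout encode_plus would produce for each sequence.
per_example = [
    {key: batch[key][i] for key in batch}
    for i in range(len(batch["input_ids"]))
]
assert per_example[0] == {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1]}
```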
@@ -114,13 +119,13 @@ class TokenizerTesterMixin:
# Now let's start the test
tokenizer = self.get_tokenizer(max_len=42)
before_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
sample_text = "He is very happy, UNwant\u00E9d,running"
before_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
tokenizer.save_pretrained(self.tmpdirname)
tokenizer = self.tokenizer_class.from_pretrained(self.tmpdirname)
after_tokens = tokenizer.encode("He is very happy, UNwant\u00E9d,running", add_special_tokens=False)
after_tokens = tokenizer.encode(sample_text, add_special_tokens=False)
self.assertListEqual(before_tokens, after_tokens)
self.assertEqual(tokenizer.max_len, 42)
@@ -128,6 +133,7 @@ class TokenizerTesterMixin:
self.assertEqual(tokenizer.max_len, 43)
def test_pickle_tokenizer(self):
"""Google pickle __getstate__ __setstate__ if you are struggling with this."""
tokenizer = self.get_tokenizer()
self.assertIsNotNone(tokenizer)
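
The new docstring's pointer at `__getstate__`/`__setstate__` refers to pickle's protocol; the round trip the test depends on is essentially this (a sketch, with `bert-base-uncased` standing in for whatever `get_tokenizer()` returns):

```python
import pickle

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

# dumps() goes through __getstate__, loads() through __setstate__;
# a tokenizer that round-trips correctly encodes identically afterwards.
restored = pickle.loads(pickle.dumps(tokenizer))
assert restored.encode("He is very happy") == tokenizer.encode("He is very happy")
```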
@@ -253,7 +259,7 @@ class TokenizerTesterMixin:
decoded = tokenizer.decode(encoded, skip_special_tokens=True)
assert special_token not in decoded
def test_required_methods_tokenizer(self):
def test_internal_consistency(self):
tokenizer = self.get_tokenizer()
input_text, output_text = self.get_input_output_texts()
@@ -263,13 +269,12 @@ class TokenizerTesterMixin:
self.assertListEqual(ids, ids_2)
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
self.assertNotEqual(len(tokens_2), 0)
text_2 = tokenizer.decode(ids)
self.assertIsInstance(text_2, str)
self.assertEqual(text_2, output_text)
self.assertNotEqual(len(tokens_2), 0)
self.assertIsInstance(text_2, str)
def test_encode_decode_with_spaces(self):
tokenizer = self.get_tokenizer()
@@ -429,10 +434,7 @@ class TokenizerTesterMixin:
def test_special_tokens_mask(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
sequence_1 = "This one too please."
# Testing single inputs
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
@@ -442,13 +444,13 @@ class TokenizerTesterMixin:
special_tokens_mask = encoded_sequence_dict["special_tokens_mask"]
self.assertEqual(len(special_tokens_mask), len(encoded_sequence_w_special))
filtered_sequence = [
(x if not special_tokens_mask[i] else None) for i, x in enumerate(encoded_sequence_w_special)
]
filtered_sequence = [x for x in filtered_sequence if x is not None]
filtered_sequence = [x for i, x in enumerate(encoded_sequence_w_special) if not special_tokens_mask[i]]
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing inputs pairs
def test_special_tokens_mask_input_pairs(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
sequence_1 = "This one too please."
encoded_sequence = tokenizer.encode(sequence_0, add_special_tokens=False)
encoded_sequence += tokenizer.encode(sequence_1, add_special_tokens=False)
encoded_sequence_dict = tokenizer.encode_plus(
@@ -464,7 +466,9 @@ class TokenizerTesterMixin:
filtered_sequence = [x for x in filtered_sequence if x is not None]
self.assertEqual(encoded_sequence, filtered_sequence)
# Testing with already existing special tokens
def test_special_tokens_mask_already_has_special_tokens(self):
tokenizer = self.get_tokenizer()
sequence_0 = "Encode this."
if tokenizer.cls_token_id == tokenizer.unk_token_id and tokenizer.cls_token_id == tokenizer.unk_token_id:
tokenizer.add_special_tokens({"cls_token": "</s>", "sep_token": "<s>"})
encoded_sequence_dict = tokenizer.encode_plus(
@@ -514,13 +518,12 @@ class TokenizerTesterMixin:
tokenizer.padding_side = "right"
padded_sequence_right = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence_right_length = len(padded_sequence_right)
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
tokenizer.padding_side = "left"
padded_sequence_left = tokenizer.encode(sequence, pad_to_max_length=True)
padded_sequence_left_length = len(padded_sequence_left)
assert sequence_length == padded_sequence_right_length
assert encoded_sequence == padded_sequence_right
assert sequence_length == padded_sequence_left_length
assert encoded_sequence == padded_sequence_left
@@ -617,6 +620,9 @@ class TokenizerTesterMixin:
self.assertIsInstance(vocab, dict)
self.assertEqual(len(vocab), len(tokenizer))
def test_conversion_reversible(self):
tokenizer = self.get_tokenizer()
vocab = tokenizer.get_vocab()
for word, ind in vocab.items():
self.assertEqual(tokenizer.convert_tokens_to_ids(word), ind)
self.assertEqual(tokenizer.convert_ids_to_tokens(ind), word)
@@ -746,6 +752,7 @@ class TokenizerTesterMixin:
@require_torch
def test_torch_encode_plus_sent_to_model(self):
import torch
from transformers import MODEL_MAPPING, TOKENIZER_MAPPING
MODEL_TOKENIZER_MAPPING = merge_model_tokenizer_mappings(MODEL_MAPPING, TOKENIZER_MAPPING)
@@ -773,8 +780,10 @@ class TokenizerTesterMixin:
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
# This should not fail
model(**encoded_sequence)
model(**batch_encoded_sequence)
with torch.no_grad(): # saves some time
model(**encoded_sequence)
model(**batch_encoded_sequence)
if self.test_rust_tokenizer:
fast_tokenizer = self.get_rust_tokenizer()
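
Outside the test harness, the pattern being exercised here (tokenizer output fed straight into a matching model, wrapped in `torch.no_grad()` since only a forward pass is needed) looks roughly like this; the BERT pairing is purely illustrative, any entry from `MODEL_MAPPING`/`TOKENIZER_MAPPING` works the same way:

```python
import torch
from transformers import BertModel, BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased")

sequence = "Encode this sequence."
single = tokenizer.encode_plus(sequence, return_tensors="pt")
batch = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")

with torch.no_grad():  # no gradients needed for a forward-only check; saves time and memory
    model(**single)
    model(**batch)
```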

View File

@@ -24,9 +24,6 @@ class DistilBertTokenizationTest(BertTokenizationTest):
tokenizer_class = DistilBertTokenizer
def get_tokenizer(self, **kwargs):
return DistilBertTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_rust_tokenizer(self, **kwargs):
return DistilBertTokenizerFast.from_pretrained(self.tmpdirname, **kwargs)

View File

@@ -64,13 +64,8 @@ class OpenAIGPTTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
return OpenAIGPTTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "lower newer"
output_text = "lower newer"
return input_text, output_text
return "lower newer", "lower newer"
def test_full_tokenizer(self):
tokenizer = OpenAIGPTTokenizer(self.vocab_file, self.merges_file)

View File

@@ -37,14 +37,6 @@ class T5TokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = T5Tokenizer(SAMPLE_VOCAB)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return T5Tokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = T5Tokenizer(SAMPLE_VOCAB)

View File

@@ -65,9 +65,6 @@ class XLMTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
with open(self.merges_file, "w") as fp:
fp.write("\n".join(merges))
def get_tokenizer(self, **kwargs):
return XLMTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "lower newer"
output_text = "lower newer"

View File

@@ -17,6 +17,7 @@
import os
import unittest
from transformers.file_utils import cached_property
from transformers.tokenization_xlm_roberta import SPIECE_UNDERLINE, XLMRobertaTokenizer
from .test_tokenization_common import TokenizerTesterMixin
@@ -37,14 +38,6 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return XLMRobertaTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = XLMRobertaTokenizer(SAMPLE_VOCAB, keep_accents=True)
@@ -121,22 +114,22 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
],
)
@cached_property
def big_tokenizer(self):
return XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
@slow
def test_tokenization_base_easy_symbols(self):
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
symbols = "Hello World!"
original_tokenizer_encodings = [0, 35378, 6661, 38, 2]
# xlmr = torch.hub.load('pytorch/fairseq', 'xlmr.base') # xlmr.large has same tokenizer
# xlmr.eval()
# xlmr.encode(symbols)
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))
@slow
def test_tokenization_base_hard_symbols(self):
tokenizer = XLMRobertaTokenizer.from_pretrained("xlm-roberta-base")
symbols = 'This is a very long text with a lot of weird characters, such as: . , ~ ? ( ) " [ ] ! : - . Also we will add words that should not exsist and be tokenized to <unk>, such as saoneuhaoesuth'
original_tokenizer_encodings = [
0,
@@ -209,4 +202,4 @@ class XLMRobertaTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# xlmr.eval()
# xlmr.encode(symbols)
self.assertListEqual(original_tokenizer_encodings, tokenizer.encode(symbols))
self.assertListEqual(original_tokenizer_encodings, self.big_tokenizer.encode(symbols))

View File

@@ -37,14 +37,6 @@ class XLNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokenizer.save_pretrained(self.tmpdirname)
def get_tokenizer(self, **kwargs):
return XLNetTokenizer.from_pretrained(self.tmpdirname, **kwargs)
def get_input_output_texts(self):
input_text = "This is a test"
output_text = "This is a test"
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)