commit e579f855fa (parent 0eabe49204)
Patrick von Platen, 2021-09-24 14:57:49 +02:00, committed by GitHub

@@ -18,6 +18,7 @@ import unittest
from transformers import FNetTokenizer, FNetTokenizerFast
from transformers.testing_utils import require_sentencepiece, require_tokenizers, slow
from transformers.tokenization_utils import AddedToken

from .test_tokenization_common import TokenizerTesterMixin
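
The import added by this hunk is AddedToken, which wraps a special-token string together with tokenization hints. A minimal sketch of the one option the new tests rely on (the <special> string is the tests' own placeholder, not a real FNet token):

from transformers.tokenization_utils import AddedToken

# lstrip=True asks the tokenizer to absorb whitespace to the left of the
# token, so "a <special> token" still matches "<special>" as one unit.
token = AddedToken("<special>", lstrip=True)
print(token.content, token.lstrip)  # -> <special> True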

@@ -141,6 +142,56 @@ class FNetTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
            tokenizer.sep_token_id
        ]

    # Overridden Tests - loading the fast tokenizer from slow just takes too long
    def test_special_tokens_initialization(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                added_tokens = [AddedToken("<special>", lstrip=True)]

                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )
                r_output = tokenizer_r.encode("Hey this is a <special> token")

                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]

                self.assertTrue(special_token_id in r_output)

                if self.test_slow_tokenizer:
                    tokenizer_p = self.tokenizer_class.from_pretrained(
                        pretrained_name, additional_special_tokens=added_tokens, **kwargs
                    )
                    p_output = tokenizer_p.encode("Hey this is a <special> token")
                    cr_output = tokenizer_r.encode("Hey this is a <special> token")

                    self.assertEqual(p_output, r_output)
                    self.assertEqual(cr_output, r_output)
                    self.assertTrue(special_token_id in p_output)
                    self.assertTrue(special_token_id in cr_output)
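
The same assertion can be reproduced outside the test harness; google/fnet-base is an assumed checkpoint name here, and any FNet checkpoint with a SentencePiece model would do:

from transformers import FNetTokenizerFast
from transformers.tokenization_utils import AddedToken

# Register <special> as an additional special token on an assumed checkpoint.
tok = FNetTokenizerFast.from_pretrained(
    "google/fnet-base", additional_special_tokens=[AddedToken("<special>", lstrip=True)]
)
ids = tok.encode("Hey this is a <special> token")
special_id = tok.encode("<special>", add_special_tokens=False)[0]
assert special_id in ids  # <special> survives encoding as a single id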

    @slow
    def test_special_tokens_initialization_from_slow(self):
        for tokenizer, pretrained_name, kwargs in self.tokenizers_list:
            with self.subTest(f"{tokenizer.__class__.__name__} ({pretrained_name})"):
                added_tokens = [AddedToken("<special>", lstrip=True)]
                tokenizer_r = self.rust_tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs, from_slow=True
                )
                special_token_id = tokenizer_r.encode("<special>", add_special_tokens=False)[0]

                tokenizer_p = self.tokenizer_class.from_pretrained(
                    pretrained_name, additional_special_tokens=added_tokens, **kwargs
                )
                p_output = tokenizer_p.encode("Hey this is a <special> token")
                cr_output = tokenizer_r.encode("Hey this is a <special> token")

                self.assertEqual(p_output, cr_output)
                self.assertTrue(special_token_id in p_output)
                self.assertTrue(special_token_id in cr_output)
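
For context on the @slow marker: from_slow=True forces the fast tokenizer to be rebuilt by converting the slow SentencePiece tokenizer instead of loading a pre-serialized tokenizer.json, which is the expensive path this test exercises. A sketch under the same assumed checkpoint name:

from transformers import FNetTokenizerFast

# Conversion from the slow tokenizer rebuilds the whole tokenizer model,
# which is why the test above is kept out of the default test run.
tok_converted = FNetTokenizerFast.from_pretrained("google/fnet-base", from_slow=True)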

    # Overridden Tests
    def test_padding(self, max_length=50):
        if not self.test_slow_tokenizer: