[MPNet] Add slow to fast tokenizer converter (#9233)

* add converter

* delete unnecessary comments
Patrick von Platen 2020-12-21 15:41:34 +01:00 committed by GitHub
parent f4432b7e01
commit 9a12b9696f
2 changed files with 44 additions and 25 deletions


@@ -74,18 +74,6 @@ class BertConverter(Converter):
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
# # Let the tokenizer know about special tokens if they are part of the vocab
# if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
tokenize_chinese_chars = False
strip_accents = False
do_lower_case = False
@@ -125,18 +113,6 @@ class FunnelConverter(Converter):
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
# # Let the tokenizer know about special tokens if they are part of the vocab
# if tokenizer.token_to_id(str(self.original_tokenizer.unk_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.unk_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.sep_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.sep_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.cls_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.cls_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.pad_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.pad_token)])
# if tokenizer.token_to_id(str(self.original_tokenizer.mask_token)) is not None:
# tokenizer.add_special_tokens([str(self.original_tokenizer.mask_token)])
tokenize_chinese_chars = False
strip_accents = False
do_lower_case = False
@@ -171,6 +147,45 @@ class FunnelConverter(Converter):
return tokenizer
class MPNetConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.vocab
tokenizer = Tokenizer(WordPiece(vocab, unk_token=str(self.original_tokenizer.unk_token)))
tokenize_chinese_chars = False
strip_accents = False
do_lower_case = False
if hasattr(self.original_tokenizer, "basic_tokenizer"):
tokenize_chinese_chars = self.original_tokenizer.basic_tokenizer.tokenize_chinese_chars
strip_accents = self.original_tokenizer.basic_tokenizer.strip_accents
do_lower_case = self.original_tokenizer.basic_tokenizer.do_lower_case
tokenizer.normalizer = normalizers.BertNormalizer(
clean_text=True,
handle_chinese_chars=tokenize_chinese_chars,
strip_accents=strip_accents,
lowercase=do_lower_case,
)
tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
cls = str(self.original_tokenizer.cls_token)
sep = str(self.original_tokenizer.sep_token)
cls_token_id = self.original_tokenizer.cls_token_id
sep_token_id = self.original_tokenizer.sep_token_id
tokenizer.post_processor = processors.TemplateProcessing(
single=f"{cls}:0 $A:0 {sep}:0",
pair=f"{cls}:0 $A:0 {sep}:0 {sep}:0 $B:1 {sep}:1", # MPNet uses two [SEP] tokens
special_tokens=[
(cls, cls_token_id),
(sep, sep_token_id),
],
)
tokenizer.decoder = decoders.WordPiece(prefix="##")
return tokenizer
class OpenAIGPTConverter(Converter):
def converted(self) -> Tokenizer:
vocab = self.original_tokenizer.encoder
@@ -602,6 +617,7 @@ SLOW_TO_FAST_CONVERTERS = {
"LongformerTokenizer": RobertaConverter,
"LxmertTokenizer": BertConverter,
"MBartTokenizer": MBartConverter,
"MPNetTokenizer": MPNetConverter,
"MobileBertTokenizer": BertConverter,
"OpenAIGPTTokenizer": OpenAIGPTConverter,
"PegasusTokenizer": PegasusConverter,


@@ -17,6 +17,7 @@
import os
import unittest
from transformers import MPNetTokenizerFast
from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer
from transformers.testing_utils import require_tokenizers, slow
@@ -27,7 +28,9 @@ from .test_tokenization_common import TokenizerTesterMixin
class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
tokenizer_class = MPNetTokenizer
test_rust_tokenizer = False
rust_tokenizer_class = MPNetTokenizerFast
test_rust_tokenizer = True
space_between_special_tokens = True
def setUp(self):
super().setUp()
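
Flipping `test_rust_tokenizer = True` makes the common tester compare the slow and fast implementations. A hedged sketch of that kind of parity check, again assuming the illustrative `microsoft/mpnet-base` checkpoint rather than the fixture vocab used by the test class:

```python
from transformers import MPNetTokenizer, MPNetTokenizerFast

slow = MPNetTokenizer.from_pretrained("microsoft/mpnet-base")
fast = MPNetTokenizerFast.from_pretrained("microsoft/mpnet-base")

text = "Hello, MPNet converters!"
# Token-level output should match between the Python and Rust backends.
assert slow.tokenize(text) == fast.tokenize(text)
# Pair inputs should also match, including the doubled sep token.
assert slow(text, text)["input_ids"] == fast(text, text)["input_ids"]
```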