Optimizes ByT5 tokenizer (#13119)

* Starting to optimize ByT5.

* Making ByT5Tokenizer faster.

* Even faster.

* Cleaning up.
Nicolas Patry 2021-08-17 10:11:58 +02:00 committed by GitHub
parent 14e9d2954c
commit 6626d8a62f

src/transformers/models/byt5/tokenization_byt5.py
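The change is easiest to see outside the diff. Below is a rough micro-benchmark (hypothetical, standard library only; tokenize_old and tokenize_new reproduce the core loops of the old and new _tokenize, with a stand-in special-token list). Splitting on special tokens is handled by the base class's tokenize() before _tokenize is called, so the fast path only ever sees plain text:

import re
import timeit

SPECIALS = ["<pad>", "</s>", "<unk>"]

def tokenize_old(text):
    # old path: regex split on special tokens, then per-chunk UTF-8 encode
    pattern = f"({'|'.join(SPECIALS)})"
    tokens = []
    for sub_text in filter(None, re.split(pattern, text)):
        if sub_text in SPECIALS:
            tokens.append(sub_text)
        else:
            tokens.extend(chr(b) for char in sub_text for b in char.encode("utf-8"))
    return tokens

def tokenize_new(text):
    # new path: one UTF-8 pass over the whole string
    return [chr(b) for b in text.encode("utf-8")]

text = "hello world " * 1000
assert tokenize_old(text) == tokenize_new(text)
print(timeit.timeit(lambda: tokenize_old(text), number=100))
print(timeit.timeit(lambda: tokenize_new(text), number=100))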

@@ -15,7 +15,6 @@
 """ Tokenization class for model ByT5."""
-import re
 import warnings
 from typing import Dict, List, Optional, Tuple
@@ -92,17 +91,21 @@ class ByT5Tokenizer(PreTrainedTokenizer):
             **kwargs,
         )
-        self._extra_ids = extra_ids
-        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
         # define special tokens dict
         self.special_tokens_encoder: Dict[int, str] = {
             self.pad_token: 0,
             self.eos_token: 1,
             self.unk_token: 2,
         }
-        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}
         self._num_special_tokens = len(self.special_tokens_encoder)
+        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
+        self._extra_ids = extra_ids
+        n = len(additional_special_tokens)
+        for i, token in enumerate(additional_special_tokens):
+            self.special_tokens_encoder[token] = self.vocab_size + i - n
+        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}

     @property
     def vocab_size(self):
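For orientation, a sketch of the id layout the constructor builds, assuming the default extra_ids=125 and the vocab_size property whose definition is truncated above (byte vocab + specials + sentinels):

NUM_SPECIAL_TOKENS = 3   # <pad>=0, </s>=1, <unk>=2
UTF_VOCAB_SIZE = 2 ** 8  # ids 3..258, one per possible byte value
EXTRA_IDS = 125          # default number of <extra_id_N> sentinels

vocab_size = UTF_VOCAB_SIZE + NUM_SPECIAL_TOKENS + EXTRA_IDS
print(vocab_size)                           # 384
# under the `vocab_size + i - n` assignment above, sentinels fill 259..383
print(NUM_SPECIAL_TOKENS + UTF_VOCAB_SIZE)  # 259, the first sentinel id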
@@ -196,35 +199,16 @@ class ByT5Tokenizer(PreTrainedTokenizer):
     def _tokenize(self, text: str) -> List[str]:
         """Take as input a string and return a list of strings (tokens) for words/sub-words"""
-
-        def _sub_tokenize(sub_text):
-            character_list = list(sub_text)
-            utf_tokens_lists = [list(char.encode("utf-8")) for char in character_list]
-            sub_tokens = [chr(utf_token) for utf_tokens in utf_tokens_lists for utf_token in utf_tokens]
-            return sub_tokens
-
-        # split on special characters
-        pattern = f"({'|'.join(self.special_tokens_encoder.keys())})"
-        sub_texts = list(filter(None, re.split(pattern, text)))
-        tokens = []
-        for sub_text in sub_texts:
-            if sub_text in self.special_tokens_encoder.keys():
-                tokens += [sub_text]
-            else:
-                tokens += _sub_tokenize(sub_text)
+        tokens = [chr(i) for i in text.encode("utf-8")]
         return tokens

     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
-        if token.startswith("<extra_id_"):
-            match = re.match(r"<extra_id_(\d+)>", token)
-            num = int(match.group(1))
-            token_id = self.vocab_size - num - 1
-        elif token in self.special_tokens_encoder:
+        if token in self.special_tokens_encoder:
             token_id = self.special_tokens_encoder[token]
-        elif len(token) > 1:
-            # token of length > 1 must be newly added tokens => set them to unk token
+        elif token in self.added_tokens_encoder:
+            token_id = self.added_tokens_encoder[token]
+        elif len(token) != 1:
             token_id = self.unk_token_id
         else:
             token_id = ord(token) + self._num_special_tokens
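Taken together, _tokenize and _convert_token_to_id now need no regex at all: dict lookups for special and added tokens, ord plus a fixed offset for everything else. A self-contained sketch (the two dicts are hypothetical stand-ins for the tokenizer's special_tokens_encoder and added_tokens_encoder):

special_tokens_encoder = {"<pad>": 0, "</s>": 1, "<unk>": 2}
added_tokens_encoder = {}
unk_token_id = 2

def tokenize(text):
    # one UTF-8 pass: every byte becomes a one-character token
    return [chr(b) for b in text.encode("utf-8")]

def token_to_id(token):
    if token in special_tokens_encoder:
        return special_tokens_encoder[token]
    if token in added_tokens_encoder:
        return added_tokens_encoder[token]
    if len(token) != 1:
        return unk_token_id  # multi-character strings must be added tokens
    return ord(token) + 3    # plain byte, offset past the three specials

tokens = tokenize("héllo")
print(tokens)                            # ['h', 'Ã', '©', 'l', 'l', 'o']
print([token_to_id(t) for t in tokens])  # [107, 198, 172, 111, 111, 114]
print(token_to_id("</s>"))               # 1, via dict lookup instead of regex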
@@ -232,35 +216,23 @@
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index < self._num_special_tokens:
+        if index in self.special_tokens_decoder:
             token = self.special_tokens_decoder[index]
-        elif index < self._utf_vocab_size + self._num_special_tokens:
-            token = chr(index - self._num_special_tokens)
         else:
-            token = f"<extra_id_{self.vocab_size - 1 - index}>"
+            token = chr(index - self._num_special_tokens)
         return token

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-
-        def _convert_sub_string(sub_chars):
-            byte_string = bytes([ord(char) for char in sub_chars])
-            return byte_string.decode("utf-8", errors="ignore")
-
         string = ""
-        sub_chars = []
         for token in tokens:
-            # if is special token
-            if len(token) > 1:
-                string += _convert_sub_string(sub_chars)
-                string += token
-                sub_chars = []
+            if token in self.special_tokens_decoder:
+                tok_string = self.special_tokens_decoder[token]
+            elif token in self.added_tokens_decoder:
+                tok_string = self.added_tokens_decoder[token]
             else:
-                sub_chars.append(token)
-
-        # add remaining chars
-        string += _convert_sub_string(sub_chars)
+                tok_string = token
+            string += tok_string
         return string

     # ByT5Tokenizer has no vocab file
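And the decode direction, as a round-trip sketch of _convert_id_to_token plus the accumulation loop in convert_tokens_to_string, assuming ASCII-only text and a stand-in special_tokens_decoder (multi-byte UTF-8 would additionally need the byte characters reassembled before display):

SPECIAL_TOKENS_DECODER = {0: "<pad>", 1: "</s>", 2: "<unk>"}

def id_to_token(index):
    # special ids resolve through the dict; everything else is a byte
    if index in SPECIAL_TOKENS_DECODER:
        return SPECIAL_TOKENS_DECODER[index]
    return chr(index - 3)

ids = [107, 104, 111, 111, 114, 1]  # "hello" followed by </s>
print("".join(id_to_token(i) for i in ids))  # hello</s>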