Optimizes ByT5 tokenizer (#13119)

* Starting to optimize ByT5.

* Making ByT5Tokenizer faster.

* Even faster.

* Cleaning up.
Nicolas Patry 2021-08-17 10:11:58 +02:00 committed by GitHub
parent 14e9d2954c
commit 6626d8a62f

src/transformers/models/byt5/tokenization_byt5.py

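As a rough way to check the speed-up the commit messages describe, one can time the slow tokenizer before and after installing this change. The checkpoint name, sample text, and loop count below are illustrative assumptions, not part of the commit:

```python
# Hypothetical micro-benchmark sketch; absolute numbers depend on the machine and
# on which transformers version (pre- or post-commit) is installed.
import timeit

from transformers import ByT5Tokenizer

tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")
text = "ByT5 operates directly on UTF-8 bytes, with a few special tokens. " * 20

seconds = timeit.timeit(lambda: tokenizer(text), number=200)
print(f"200 encodes: {seconds:.2f}s")
```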
@@ -15,7 +15,6 @@
 """ Tokenization class for model ByT5."""
-import re
 import warnings
 from typing import Dict, List, Optional, Tuple
@@ -92,17 +91,21 @@ class ByT5Tokenizer(PreTrainedTokenizer):
             **kwargs,
         )
+        self._extra_ids = extra_ids
+        self._utf_vocab_size = 2 ** 8  # utf is 8 bits

         # define special tokens dict
         self.special_tokens_encoder: Dict[int, str] = {
             self.pad_token: 0,
             self.eos_token: 1,
             self.unk_token: 2,
         }
-        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}
         self._num_special_tokens = len(self.special_tokens_encoder)
-        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
-        self._extra_ids = extra_ids
+        n = len(additional_special_tokens)
+        for i, token in enumerate(additional_special_tokens):
+            self.special_tokens_encoder[token] = self.vocab_size + i - n - 1
+        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}

     @property
     def vocab_size(self):
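In other words, the sentinel ids that used to be parsed out of `<extra_id_NN>` strings with a regex at encode time are now computed once in `__init__` and stored in `special_tokens_encoder`. A minimal standalone sketch of that lookup table, assuming the default ByT5 configuration (3 special tokens, 256 byte values, 125 extra ids, so `vocab_size == 384`):

```python
from typing import Dict

num_special_tokens = 3           # <pad>, </s>, <unk>
utf_vocab_size = 2 ** 8          # one id per byte value
extra_ids = 125                  # sentinel tokens <extra_id_0> ... <extra_id_124>
vocab_size = utf_vocab_size + num_special_tokens + extra_ids  # 384

special_tokens_encoder: Dict[str, int] = {"<pad>": 0, "</s>": 1, "<unk>": 2}
additional_special_tokens = [f"<extra_id_{i}>" for i in range(extra_ids)]

# Assign every sentinel an id once, mirroring the formula in the diff above,
# instead of re.match-ing "<extra_id_(\d+)>" on every token during encoding.
n = len(additional_special_tokens)
for i, token in enumerate(additional_special_tokens):
    special_tokens_encoder[token] = vocab_size + i - n - 1

print(special_tokens_encoder["<extra_id_0>"])    # 258
print(special_tokens_encoder["<extra_id_124>"])  # 382
```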
@@ -196,35 +199,16 @@ class ByT5Tokenizer(PreTrainedTokenizer):
     def _tokenize(self, text: str) -> List[str]:
         """Take as input a string and return a list of strings (tokens) for words/sub-words"""
-        def _sub_tokenize(sub_text):
-            character_list = list(sub_text)
-            utf_tokens_lists = [list(char.encode("utf-8")) for char in character_list]
-            sub_tokens = [chr(utf_token) for utf_tokens in utf_tokens_lists for utf_token in utf_tokens]
-            return sub_tokens
-        # split on special characters
-        pattern = f"({'|'.join(self.special_tokens_encoder.keys())})"
-        sub_texts = list(filter(None, re.split(pattern, text)))
-        tokens = []
-        for sub_text in sub_texts:
-            if sub_text in self.special_tokens_encoder.keys():
-                tokens += [sub_text]
-            else:
-                tokens += _sub_tokenize(sub_text)
+        tokens = list(text)
         return tokens

     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
-        if token.startswith("<extra_id_"):
-            match = re.match(r"<extra_id_(\d+)>", token)
-            num = int(match.group(1))
-            token_id = self.vocab_size - num - 1
-        elif token in self.special_tokens_encoder:
+        if token in self.special_tokens_encoder:
             token_id = self.special_tokens_encoder[token]
-        elif len(token) > 1:
-            # token of length > 1 must be newly added tokens => set them to unk token
+        elif token in self.added_tokens_encoder:
+            token_id = self.added_tokens_encoder[token]
+        elif len(token) != 1:
             token_id = self.unk_token_id
         else:
             token_id = ord(token) + self._num_special_tokens
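Read together with `_tokenize`, the encode path is now a chain of dictionary lookups plus an `ord()` offset instead of a regex match per token. A rough standalone illustration of that lookup order (the helper name and the default offsets are assumptions for the sketch, not library API):

```python
def token_to_id(token: str, special: dict, added: dict,
                num_special_tokens: int = 3, unk_id: int = 2) -> int:
    if token in special:        # pre-computed special/sentinel ids
        return special[token]
    if token in added:          # tokens added by the user at runtime
        return added[token]
    if len(token) != 1:         # unregistered multi-character strings fall back to <unk>
        return unk_id
    return ord(token) + num_special_tokens  # plain characters are offset past the specials

special = {"<pad>": 0, "</s>": 1, "<unk>": 2}
print([token_to_id(c, special, {}) for c in "hi"])  # [107, 108]
```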
@@ -232,35 +216,23 @@ class ByT5Tokenizer(PreTrainedTokenizer):
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index < self._num_special_tokens:
+        if index in self.special_tokens_decoder:
             token = self.special_tokens_decoder[index]
-        elif index < self._utf_vocab_size + self._num_special_tokens:
-            token = chr(index - self._num_special_tokens)
         else:
-            token = f"<extra_id_{self.vocab_size - 1 - index}>"
+            token = chr(index - self._num_special_tokens)
         return token

     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-        def _convert_sub_string(sub_chars):
-            byte_string = bytes([ord(char) for char in sub_chars])
-            return byte_string.decode("utf-8", errors="ignore")
         string = ""
-        sub_chars = []
         for token in tokens:
-            # if is special token
-            if len(token) > 1:
-                string += _convert_sub_string(sub_chars)
-                string += token
-                sub_chars = []
+            if token in self.special_tokens_decoder:
+                tok_string = self.special_tokens_decoder[token]
+            elif token in self.added_tokens_decoder:
+                tok_string = self.added_tokens_decoder[token]
             else:
-                sub_chars.append(token)
-        # add remaining chars
-        string += _convert_sub_string(sub_chars)
+                tok_string = token
+            string += tok_string
         return string

     # ByT5Tokenizer has no vocab file
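The decode side is now symmetric: ids found in `special_tokens_decoder` (or `added_tokens_decoder`) map back to their stored strings, and every other id becomes a character via `chr(index - num_special_tokens)` before being concatenated. A small sketch of that round trip, using the same assumed default offsets as above:

```python
num_special_tokens = 3
special_tokens_decoder = {0: "<pad>", 1: "</s>", 2: "<unk>"}

def ids_to_text(ids):
    # Special ids come from the reverse lookup table; the rest are offset back to characters.
    tokens = [
        special_tokens_decoder[i] if i in special_tokens_decoder else chr(i - num_special_tokens)
        for i in ids
    ]
    return "".join(tokens)

print(ids_to_text([107, 108, 1]))  # hi</s>
```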