Mirror of https://github.com/huggingface/transformers.git, synced 2025-08-01 02:31:11 +06:00
Optimizes ByT5 tokenizer (#13119)
* Starting to optimize ByT5.
* Making ByT5Tokenizer faster.
* Even faster.
* Cleaning up.
This commit is contained in:
parent 14e9d2954c
commit 6626d8a62f
@@ -15,7 +15,6 @@
 """ Tokenization class for model ByT5."""
 
 
-import re
 import warnings
 from typing import Dict, List, Optional, Tuple
 
@@ -92,17 +91,21 @@ class ByT5Tokenizer(PreTrainedTokenizer):
             **kwargs,
         )
 
+        self._extra_ids = extra_ids
+        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
+
         # define special tokens dict
         self.special_tokens_encoder: Dict[int, str] = {
             self.pad_token: 0,
             self.eos_token: 1,
             self.unk_token: 2,
         }
-        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}
 
         self._num_special_tokens = len(self.special_tokens_encoder)
-        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
-        self._extra_ids = extra_ids
+        n = len(additional_special_tokens)
+        for i, token in enumerate(additional_special_tokens):
+            self.special_tokens_encoder[token] = self.vocab_size + i - n - 1
+        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}
 
     @property
     def vocab_size(self):
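For orientation, here is a small standalone sketch of the id layout this constructor sets up. It is not part of the commit; the constants simply mirror the tokenizer's defaults (3 special tokens, 256 byte values, extra_ids=125), and the computation is done outside the library.

# Standalone illustration of the ByT5 id layout; not library code.
NUM_SPECIAL = 3        # <pad> -> 0, </s> -> 1, <unk> -> 2
UTF_VOCAB = 2 ** 8     # one id per possible byte value
EXTRA_IDS = 125        # sentinel tokens <extra_id_0> ... <extra_id_124> (default)

vocab_size = UTF_VOCAB + NUM_SPECIAL + EXTRA_IDS  # 384 under these defaults

# a byte b (0..255) is kept as the single character chr(b) and shifted past the special tokens
print(vocab_size)                          # 384
print(0 + NUM_SPECIAL, 255 + NUM_SPECIAL)  # 3 258 -> the id range used for raw bytes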
@@ -196,35 +199,16 @@ class ByT5Tokenizer(PreTrainedTokenizer):
 
     def _tokenize(self, text: str) -> List[str]:
         """Take as input a string and return a list of strings (tokens) for words/sub-words"""
-
-        def _sub_tokenize(sub_text):
-            character_list = list(sub_text)
-            utf_tokens_lists = [list(char.encode("utf-8")) for char in character_list]
-            sub_tokens = [chr(utf_token) for utf_tokens in utf_tokens_lists for utf_token in utf_tokens]
-            return sub_tokens
-
-        # split on special characters
-        pattern = f"({'|'.join(self.special_tokens_encoder.keys())})"
-        sub_texts = list(filter(None, re.split(pattern, text)))
-        tokens = []
-        for sub_text in sub_texts:
-            if sub_text in self.special_tokens_encoder.keys():
-                tokens += [sub_text]
-            else:
-                tokens += _sub_tokenize(sub_text)
-
+        tokens = list(text)
         return tokens
 
     def _convert_token_to_id(self, token):
         """Converts a token (str) in an id using the vocab."""
-        if token.startswith("<extra_id_"):
-            match = re.match(r"<extra_id_(\d+)>", token)
-            num = int(match.group(1))
-            token_id = self.vocab_size - num - 1
-        elif token in self.special_tokens_encoder:
+        if token in self.special_tokens_encoder:
             token_id = self.special_tokens_encoder[token]
-        elif len(token) > 1:
-            # token of length > 1 must be newly added tokens => set them to unk token
+        elif token in self.added_tokens_encoder:
+            token_id = self.added_tokens_encoder[token]
+        elif len(token) != 1:
             token_id = self.unk_token_id
         else:
             token_id = ord(token) + self._num_special_tokens
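To make the change above concrete: the removed `_sub_tokenize` helper turned a substring into one single-character token per UTF-8 byte, while the new `_tokenize` simply returns `list(text)` and leaves the id mapping to `_convert_token_to_id`, so no regex pattern is built or applied per call. Below is a minimal standalone restatement of what the removed helper computed; the function name is hypothetical and it is shown only for illustration.

def byte_split(sub_text):
    # One token per UTF-8 byte, each byte kept as chr(byte);
    # equivalent to the removed _sub_tokenize helper above.
    return [chr(b) for b in sub_text.encode("utf-8")]

print(byte_split("héllo"))  # ['h', 'Ã', '©', 'l', 'l', 'o'] -- 6 byte-level tokens for 5 characters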
@@ -232,35 +216,23 @@ class ByT5Tokenizer(PreTrainedTokenizer):
 
     def _convert_id_to_token(self, index):
         """Converts an index (integer) in a token (str) using the vocab."""
-        if index < self._num_special_tokens:
+        if index in self.special_tokens_decoder:
             token = self.special_tokens_decoder[index]
-        elif index < self._utf_vocab_size + self._num_special_tokens:
-            token = chr(index - self._num_special_tokens)
         else:
-            token = f"<extra_id_{self.vocab_size - 1 - index}>"
+            token = chr(index - self._num_special_tokens)
         return token
 
     def convert_tokens_to_string(self, tokens):
         """Converts a sequence of tokens (string) in a single string."""
-
-        def _convert_sub_string(sub_chars):
-            byte_string = bytes([ord(char) for char in sub_chars])
-            return byte_string.decode("utf-8", errors="ignore")
-
         string = ""
-        sub_chars = []
         for token in tokens:
-            # if is special token
-            if len(token) > 1:
-                string += _convert_sub_string(sub_chars)
-                string += token
-                sub_chars = []
+            if token in self.special_tokens_decoder:
+                tok_string = self.special_tokens_decoder[token]
+            elif token in self.added_tokens_decoder:
+                tok_string = self.added_tokens_decoder[token]
             else:
-                sub_chars.append(token)
-
-        # add remaining chars
-        string += _convert_sub_string(sub_chars)
-
+                tok_string = token
+            string += tok_string
         return string
 
     # ByT5Tokenizer has no vocab file
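As a final illustration of the id shift that `_convert_token_to_id` and the new `_convert_id_to_token` else-branch rely on, here is a rough standalone sketch of the shift and its inverse. The helper names are hypothetical, the input is ASCII, and no special or extra tokens are involved.

NUM_SPECIAL = 3  # offset reserved for <pad>, </s>, <unk>

def to_ids(tokens):
    # single-character tokens map to ord(token) + NUM_SPECIAL
    return [ord(t) + NUM_SPECIAL for t in tokens]

def to_tokens(ids):
    # inverse shift, as in the new _convert_id_to_token else-branch
    return [chr(i - NUM_SPECIAL) for i in ids]

ids = to_ids(list("hello"))
print(ids)                      # [107, 104, 111, 111, 114]
print("".join(to_tokens(ids)))  # hello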