Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Optimizes ByT5 tokenizer (#13119)
* Starting to optimize ByT5.
* Making ByT5Tokenizer faster.
* Even faster.
* Cleaning up.
This commit is contained in:
parent 14e9d2954c
commit 6626d8a62f
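For orientation, a minimal usage sketch of the byte-level scheme this tokenizer implements: each character is split into its UTF-8 bytes, and every byte maps to one id offset past the three special tokens (pad=0, eos=1, unk=2), so no vocab file is needed. The checkpoint name and the printed ids below are illustrative, not taken from this commit.

from transformers import ByT5Tokenizer

# google/byt5-small is just an example checkpoint; any ByT5 checkpoint behaves the same way.
tokenizer = ByT5Tokenizer.from_pretrained("google/byt5-small")

ids = tokenizer("héllo").input_ids
# "h" -> 1 byte, "é" -> 2 bytes (0xC3 0xA9), "l", "l", "o" -> 1 byte each, plus the trailing </s>;
# each byte id is shifted by the 3 special tokens (pad=0, eos=1, unk=2)
print(ids)                     # expected: [107, 198, 172, 111, 111, 114, 1]
print(tokenizer.decode(ids))   # "héllo</s>"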
@@ -15,7 +15,6 @@
""" Tokenization class for model ByT5."""


import re
import warnings
from typing import Dict, List, Optional, Tuple

@@ -92,17 +91,21 @@ class ByT5Tokenizer(PreTrainedTokenizer):
            **kwargs,
        )

        self._extra_ids = extra_ids

        self._utf_vocab_size = 2 ** 8  # utf is 8 bits

        # define special tokens dict
        self.special_tokens_encoder: Dict[int, str] = {
            self.pad_token: 0,
            self.eos_token: 1,
            self.unk_token: 2,
        }
        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}

        self._num_special_tokens = len(self.special_tokens_encoder)
        self._utf_vocab_size = 2 ** 8  # utf is 8 bits
        self._extra_ids = extra_ids
        n = len(additional_special_tokens)
        for i, token in enumerate(additional_special_tokens):
            self.special_tokens_encoder[token] = self.vocab_size + i - n - 1
        self.special_tokens_decoder: Dict[str, int] = {v: k for k, v in self.special_tokens_encoder.items()}

    @property
    def vocab_size(self):
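The constructor above fixes the id layout. A small standalone sketch of that layout, assuming the default configuration (extra_ids=125); this is an illustration, not the library code itself:

# Standalone sketch of the ByT5 id layout (assumed defaults: 3 special tokens,
# 256 byte ids, 125 extra_id sentinels, so vocab_size = 3 + 256 + 125 = 384).
NUM_SPECIAL = 3          # pad=0, eos=1, unk=2
UTF_VOCAB_SIZE = 2 ** 8  # one id per possible byte value
EXTRA_IDS = 125
VOCAB_SIZE = UTF_VOCAB_SIZE + NUM_SPECIAL + EXTRA_IDS  # 384

special_tokens_encoder = {"<pad>": 0, "</s>": 1, "<unk>": 2}

def byte_to_id(byte_value: int) -> int:
    """Raw bytes are shifted past the three special tokens."""
    return byte_value + NUM_SPECIAL

# ids 3..258 cover every possible byte; the <extra_id_*> sentinels fill the
# remaining slots up to VOCAB_SIZE, mirroring the loop over
# additional_special_tokens in __init__ above.
assert special_tokens_encoder["</s>"] == 1
assert byte_to_id(0x68) == 107   # byte for "h"
assert VOCAB_SIZE - (NUM_SPECIAL + UTF_VOCAB_SIZE) == EXTRA_IDS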
@@ -196,35 +199,16 @@ class ByT5Tokenizer(PreTrainedTokenizer):

    def _tokenize(self, text: str) -> List[str]:
        """Take as input a string and return a list of strings (tokens) for words/sub-words"""

        def _sub_tokenize(sub_text):
            character_list = list(sub_text)
            utf_tokens_lists = [list(char.encode("utf-8")) for char in character_list]
            sub_tokens = [chr(utf_token) for utf_tokens in utf_tokens_lists for utf_token in utf_tokens]
            return sub_tokens

        # split on special characters
        pattern = f"({'|'.join(self.special_tokens_encoder.keys())})"
        sub_texts = list(filter(None, re.split(pattern, text)))
        tokens = []
        for sub_text in sub_texts:
            if sub_text in self.special_tokens_encoder.keys():
                tokens += [sub_text]
            else:
                tokens += _sub_tokenize(sub_text)

        tokens = list(text)
        return tokens

    def _convert_token_to_id(self, token):
        """Converts a token (str) in an id using the vocab."""
        if token.startswith("<extra_id_"):
            match = re.match(r"<extra_id_(\d+)>", token)
            num = int(match.group(1))
            token_id = self.vocab_size - num - 1
        elif token in self.special_tokens_encoder:
        if token in self.special_tokens_encoder:
            token_id = self.special_tokens_encoder[token]
        elif len(token) > 1:
            # token of length > 1 must be newly added tokens => set them to unk token
        elif token in self.added_tokens_encoder:
            token_id = self.added_tokens_encoder[token]
        elif len(token) != 1:
            token_id = self.unk_token_id
        else:
            token_id = ord(token) + self._num_special_tokens
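The removed `_tokenize` paid for a regex split over every special token plus a nested comprehension per character, while the replacement returns the character list directly, with special and added tokens now resolved in `_convert_token_to_id`. A rough standalone timing sketch of just the byte-expansion step (the helper names, benchmark text, and repetition count are made up for illustration):

import timeit

text = "The quick brown fox jumps over the lazy dog. " * 50

def expand_old(sub_text):
    # per-character UTF-8 expansion via nested comprehensions, as in the removed helper
    character_list = list(sub_text)
    utf_tokens_lists = [list(char.encode("utf-8")) for char in character_list]
    return [chr(utf_token) for utf_tokens in utf_tokens_lists for utf_token in utf_tokens]

def expand_flat(sub_text):
    # single pass over the already-encoded byte string
    return [chr(b) for b in sub_text.encode("utf-8")]

# both produce the same per-byte token list, the flat version just does less work
assert expand_old(text) == expand_flat(text)
print(timeit.timeit(lambda: expand_old(text), number=1000))
print(timeit.timeit(lambda: expand_flat(text), number=1000))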
@@ -232,35 +216,23 @@ class ByT5Tokenizer(PreTrainedTokenizer):

    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (str) using the vocab."""
        if index < self._num_special_tokens:
        if index in self.special_tokens_decoder:
            token = self.special_tokens_decoder[index]
        elif index < self._utf_vocab_size + self._num_special_tokens:
            token = chr(index - self._num_special_tokens)
        else:
            token = f"<extra_id_{self.vocab_size - 1 - index}>"
            token = chr(index - self._num_special_tokens)
        return token

    def convert_tokens_to_string(self, tokens):
        """Converts a sequence of tokens (string) in a single string."""

        def _convert_sub_string(sub_chars):
            byte_string = bytes([ord(char) for char in sub_chars])
            return byte_string.decode("utf-8", errors="ignore")

        string = ""
        sub_chars = []
        for token in tokens:
            # if is special token
            if len(token) > 1:
                string += _convert_sub_string(sub_chars)
                string += token
                sub_chars = []
            if token in self.special_tokens_decoder:
                tok_string = self.special_tokens_decoder[token]
            elif token in self.added_tokens_decoder:
                tok_string = self.added_tokens_decoder[token]
            else:
                sub_chars.append(token)

        # add remaining chars
        string += _convert_sub_string(sub_chars)

                tok_string = token
            string += tok_string
        return string

    # ByT5Tokenizer has no vocab file
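Both the old and the new `convert_tokens_to_string` have to stitch single-byte tokens back into UTF-8 while letting multi-character special tokens pass through unchanged. A standalone sketch of that idea, with a made-up helper name and token list:

def stitch_tokens(tokens, special_tokens=frozenset({"<pad>", "</s>", "<unk>"})):
    """Rebuild a string from single-character byte tokens, passing special tokens through."""
    out, byte_buf = "", bytearray()
    for token in tokens:
        if token in special_tokens or len(token) > 1:
            # flush any pending bytes before emitting the special token verbatim
            out += byte_buf.decode("utf-8", errors="ignore") + token
            byte_buf.clear()
        else:
            byte_buf.append(ord(token))  # each ordinary token encodes one byte (0-255)
    out += byte_buf.decode("utf-8", errors="ignore")
    return out

# "é" was tokenized into its two UTF-8 bytes, 0xC3 and 0xA9
print(stitch_tokens(["h", chr(0xC3), chr(0xA9), "</s>"]))  # -> "hé</s>"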