mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-30 17:52:35 +06:00
OpenAI GPT Tokenizer can fall back to using BERT BasicTokenizer
This commit is contained in:
parent e7cfc46fc1
commit c6bea08448
@@ -26,6 +26,7 @@ from io import open
 from tqdm import tqdm
 
 from .file_utils import cached_path
+from .tokenization import BasicTokenizer
 
 logger = logging.getLogger(__name__)
 
@@ -72,8 +73,9 @@ class OpenAIGPTTokenizer(object):
     """
     BPE tokenizer. Peculiarities:
         - lower case all inputs
-        - uses SpaCy tokenizer
-        - special tokens: additional symbols (ex: "__classify__") to add to a vocabulary.
+        - uses SpaCy tokenizer and ftfy for pre-BPE tokenization if they are installed, falling back to BERT's BasicTokenizer if not.
+        - argument special_tokens and function set_special_tokens:
+            can be used to add additional symbols (ex: "__classify__") to a vocabulary.
     """
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
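A minimal usage sketch for the behaviour the docstring describes. It assumes the top-level export and the 'openai-gpt' shortcut name that released versions of pytorch-pretrained-bert provide; treat it as an illustration, not the canonical API.

from pytorch_pretrained_bert import OpenAIGPTTokenizer

# Downloads/caches the vocabulary and BPE merge files for the pretrained model.
tokenizer = OpenAIGPTTokenizer.from_pretrained('openai-gpt')

# Inputs are lower-cased before BPE, so both calls should yield the same sub-tokens.
assert tokenizer.tokenize("Hello World") == tokenizer.tokenize("hello world")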
@@ -122,12 +124,15 @@ class OpenAIGPTTokenizer(object):
         try:
             import ftfy
             import spacy
+            self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
+            self.fix_text = ftfy.fix_text
         except ImportError:
-            raise ImportError("Please install ftfy and spacy to use OpenAI GPT tokenizer.")
+            logger.warning("ftfy or spacy is not installed, using BERT BasicTokenizer instead of SpaCy & ftfy.")
+            self.nlp = BasicTokenizer(do_lower_case=True,
+                                      never_split=special_tokens if special_tokens is not None else [])
+            self.fix_text = None
 
         self.max_len = max_len if max_len is not None else int(1e12)
-        self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
-        self.fix_text = ftfy.fix_text
         self.encoder = json.load(open(vocab_file, encoding="utf-8"))
         self.decoder = {v:k for k,v in self.encoder.items()}
         merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
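To make the new control flow concrete: self.fix_text doubles as the backend flag, exactly as the branches above use it. A small sketch, continuing the tokenizer from the first example, of how a caller could check which pre-BPE backend was selected:

if tokenizer.fix_text is None:
    # ImportError path: tokenizer.nlp is BERT's BasicTokenizer
    print("spaCy/ftfy unavailable, using the BasicTokenizer fallback")
else:
    # Default path: tokenizer.nlp is a spaCy pipeline and fix_text is ftfy.fix_text
    print("using spaCy + ftfy for pre-BPE tokenization")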
@@ -150,6 +155,9 @@ class OpenAIGPTTokenizer(object):
             return
         self.special_tokens = dict((tok, len(self.encoder) + i) for i, tok in enumerate(special_tokens))
         self.special_tokens_decoder = {v:k for k, v in self.special_tokens.items()}
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer: we can update the tokenizer
+            self.nlp.never_split = special_tokens
         logger.info("Special tokens {}".format(self.special_tokens))
 
     def bpe(self, token):
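A sketch of the special-token flow this hunk extends, continuing the same tokenizer; the token names here are purely illustrative:

tokenizer.set_special_tokens(["_start_", "_delimiter_", "_classify_"])

# Special tokens are appended after the base vocabulary, receiving ids
# len(encoder), len(encoder) + 1, ... On the BasicTokenizer fallback they are
# also written into nlp.never_split so pre-BPE splitting leaves them intact.
print(tokenizer.special_tokens)          # token -> id
print(tokenizer.special_tokens_decoder)  # id -> token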
@@ -198,9 +206,16 @@ class OpenAIGPTTokenizer(object):
     def tokenize(self, text):
         """ Tokenize a string. """
         split_tokens = []
-        text = self.nlp(text_standardize(self.fix_text(text)))
-        for token in text:
-            split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
+        if self.fix_text is None:
+            # Using BERT's BasicTokenizer
+            text = self.nlp.tokenize(text)
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token).split(' ')])
+        else:
+            # Using SpaCy & ftfy (original tokenization process of OpenAI GPT)
+            text = self.nlp(text_standardize(self.fix_text(text)))
+            for token in text:
+                split_tokens.extend([t for t in self.bpe(token.text.lower()).split(' ')])
         return split_tokens
 
     def convert_tokens_to_ids(self, tokens):
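Both branches return plain BPE sub-token strings, so downstream code is unchanged; only the pre-BPE splitting differs (spaCy plus ftfy text cleanup versus BERT's simpler whitespace and punctuation splitting), which can make the two backends produce slightly different sub-tokens for the same input. Continuing the sketch:

# Exact sub-tokens depend on which backend is active, but both paths end in bpe().
tokens = tokenizer.tokenize("What a lovely day!")
ids = tokenizer.convert_tokens_to_ids(tokens)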
@@ -219,8 +234,8 @@ class OpenAIGPTTokenizer(object):
         if len(ids) > self.max_len:
             raise ValueError(
                 "Token indices sequence length is longer than the specified maximum "
-                " sequence length for this BERT model ({} > {}). Running this"
-                " sequence through BERT will result in indexing errors".format(len(ids), self.max_len)
+                " sequence length for this OpenAI GPT model ({} > {}). Running this"
+                " sequence through the model will result in indexing errors".format(len(ids), self.max_len)
             )
         return ids
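Finally, a sketch of the guard whose message this hunk fixes. max_len defaults to int(1e12) in __init__ (see the earlier hunk) and is meant to be capped at the model's maximum sequence length (512 positions for the original OpenAI GPT); the truncation below is just one hypothetical way for a caller to stay under it, continuing the same tokenizer.

very_long_document = "a lovely day " * 100000   # any overly long input string
tokens = tokenizer.tokenize(very_long_document)
if len(tokens) > tokenizer.max_len:
    tokens = tokens[:tokenizer.max_len]          # truncate instead of hitting the ValueError
ids = tokenizer.convert_tokens_to_ids(tokens)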