Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-01 18:51:14 +06:00)

commit 3e65f255dc
parent 616743330e

add serialization semantics to tokenizers - fix transfo-xl tokenizer
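The diff below adds a save_vocabulary(vocab_path) method to BertTokenizer, GPT2Tokenizer, OpenAIGPTTokenizer and TransfoXLTokenizer, so that a tokenizer loaded from the hosted archives can be written to a local directory and reloaded from it. A minimal sketch of the intended round trip for BertTokenizer; the directory name is illustrative and not part of the commit, and from_pretrained is assumed to accept a local directory containing the saved vocab file:

import os

from pytorch_pretrained_bert import BertTokenizer

save_dir = "./saved_tokenizer"  # illustrative path, not taken from the commit
os.makedirs(save_dir, exist_ok=True)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
tokenizer.save_vocabulary(save_dir)  # writes the vocab file (VOCAB_NAME) into save_dir

# reload from the local copy instead of the hosted archive
reloaded = BertTokenizer.from_pretrained(save_dir)
assert reloaded.tokenize("Hello world!") == tokenizer.tokenize("Hello world!")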
@@ -28,7 +28,7 @@ import math
 
 import torch
 
-from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus
+from pytorch_pretrained_bert import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
@@ -80,6 +80,7 @@ def main():
     # The pre-processing involve computing word frequencies to prepare the Adaptive input and SoftMax
     # and tokenizing the dataset
     # The pre-processed corpus is a convertion (using the conversion script )
+    tokenizer = TransfoXLTokenizer.from_pretrained(args.model_name)
     corpus = TransfoXLCorpus.from_pretrained(args.model_name)
     ntokens = len(corpus.vocab)
 
@@ -134,6 +134,19 @@ class BertTokenizer(object):
             tokens.append(self.ids_to_tokens[i])
         return tokens
 
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        index = 0
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        with open(vocab_file, "w", encoding="utf-8") as writer:
+            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
+                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
+                    index = token_index
+                writer.write(token + u'\n')
+                index += 1
+
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
         """
@@ -187,6 +187,22 @@ class GPT2Tokenizer(object):
         self.cache[token] = word
         return word
 
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        json.dump(self.encoder, vocab_file)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(bpe_tokens + u'\n')
+                index += 1
+
     def encode(self, text):
         bpe_tokens = []
         for token in re.findall(self.pat, text):
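The GPT-2 method above (and the identical OpenAIGPTTokenizer method below) splits serialization across two files: a JSON dump of the encoder (token to id) and a merges file that lists BPE merge pairs in rank order under a '#version: 0.2' header, which is why the loop warns when the bpe_ranks indices are not consecutive. A standalone sketch of how a merges file in the standard GPT-2 layout (one space-separated pair per line after the header) maps back onto bpe_ranks; this is not the library's loading code, and the vocab.json / merges.txt names are assumed values of VOCAB_NAME and MERGES_NAME rather than something shown in the hunk:

import json

# token -> id mapping, as written from self.encoder
with open("vocab.json", encoding="utf-8") as f:  # assumed file name (VOCAB_NAME)
    encoder = json.load(f)

# merge pairs in rank order, skipping the "#version" header line
with open("merges.txt", encoding="utf-8") as f:  # assumed file name (MERGES_NAME)
    merges = [line.split() for line in f.read().splitlines()[1:] if line]
bpe_ranks = {tuple(pair): rank for rank, pair in enumerate(merges)}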
@@ -261,3 +261,19 @@ class OpenAIGPTTokenizer(object):
             ).replace(" 's", "'s").replace(" t ", "'t ").replace(" s ", "'s ").replace(" m ", "'m "
             ).replace(" 've", "'ve")
         return out_string
+
+    def save_vocabulary(self, vocab_path):
+        """Save the tokenizer vocabulary to a path."""
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        merge_file = os.path.join(vocab_path, MERGES_NAME)
+        json.dump(self.encoder, vocab_file)
+        index = 0
+        with open(merge_file, "w", encoding="utf-8") as writer:
+            writer.write(u'#version: 0.2\n')
+            for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
+                if index != token_index:
+                    logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
+                                   " Please check that the tokenizer is not corrupted!".format(merge_file))
+                    index = token_index
+                writer.write(bpe_tokens + u'\n')
+                index += 1
@@ -63,7 +63,10 @@ class TransfoXLTokenizer(object):
         if pretrained_model_name_or_path in PRETRAINED_VOCAB_ARCHIVE_MAP:
             vocab_file = PRETRAINED_VOCAB_ARCHIVE_MAP[pretrained_model_name_or_path]
         else:
-            vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            if os.path.isdir(pretrained_model_name_or_path):
+                vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
+            else:
+                vocab_file = pretrained_model_name_or_path
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -141,6 +144,11 @@ class TransfoXLTokenizer(object):
         else:
             raise ValueError('No <unkown> token in vocabulary')
 
+    def save_vocabulary(self, vocab_path):
+        index = 0
+        vocab_file = os.path.join(vocab_path, VOCAB_NAME)
+        torch.save(self.__dict__, vocab_file)
+
     def build_vocab(self):
         if self.vocab_file:
             print('building vocab from {}'.format(self.vocab_file))
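Unlike the plain-text vocab and merges files above, the Transformer-XL tokenizer is serialized in one shot: torch.save pickles the tokenizer's whole __dict__ into VOCAB_NAME under vocab_path. Together with the directory handling added to from_pretrained in the previous hunk, this gives a save/reload round trip. A minimal sketch, with an illustrative directory name, assuming from_pretrained restores the state it finds in that file:

import os

from pytorch_pretrained_bert import TransfoXLTokenizer

save_dir = "./txl_tokenizer"  # illustrative path, not taken from the commit
os.makedirs(save_dir, exist_ok=True)

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
tokenizer.save_vocabulary(save_dir)  # torch.save of the tokenizer state into save_dir

# from_pretrained now accepts a plain directory and resolves VOCAB_NAME inside it
restored = TransfoXLTokenizer.from_pretrained(save_dir)
assert len(restored) == len(tokenizer)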
@@ -245,82 +253,24 @@ class TransfoXLTokenizer(object):
     def __len__(self):
         return len(self.idx2sym)
 
-    def _run_split_on_punc(self, text):
-        """Splits punctuation on a piece of text."""
-        if text in self.never_split:
-            return [text]
-        chars = list(text)
-        i = 0
-        start_new_word = True
-        output = []
-        while i < len(chars):
-            char = chars[i]
-            if _is_punctuation(char):
-                output.append([char])
-                start_new_word = True
-            else:
-                if start_new_word:
-                    output.append([])
-                start_new_word = False
-                output[-1].append(char)
-            i += 1
-
-        return ["".join(x) for x in output]
-
-    def _run_strip_accents(self, text):
-        """Strips accents from a piece of text."""
-        text = unicodedata.normalize("NFD", text)
-        output = []
-        for char in text:
-            cat = unicodedata.category(char)
-            if cat == "Mn":
-                continue
-            output.append(char)
-        return "".join(output)
-
-    def _clean_text(self, text):
-        """Performs invalid character removal and whitespace cleanup on text."""
-        output = []
-        for char in text:
-            cp = ord(char)
-            if cp == 0 or cp == 0xfffd or _is_control(char):
-                continue
-            if _is_whitespace(char):
-                output.append(" ")
-            else:
-                output.append(char)
-        return "".join(output)
-
-    def whitespace_tokenize(self, text):
-        """Runs basic whitespace cleaning and splitting on a piece of text."""
-        text = text.strip()
-        if not text:
-            return []
-        if self.delimiter == '':
-            tokens = text
-        else:
-            tokens = text.split(self.delimiter)
-        return tokens
-
     def tokenize(self, line, add_eos=False, add_double_eos=False):
-        line = self._clean_text(line)
         line = line.strip()
+        # convert to lower case
+        if self.lower_case:
+            line = line.lower()
 
-        symbols = self.whitespace_tokenize(line)
-        split_symbols = []
-        for symbol in symbols:
-            if self.lower_case and symbol not in self.never_split:
-                symbol = symbol.lower()
-                symbol = self._run_strip_accents(symbol)
-            split_symbols.extend(self._run_split_on_punc(symbol))
+        # empty delimiter '' will evaluate False
+        if self.delimiter == '':
+            symbols = line
+        else:
+            symbols = line.split(self.delimiter)
 
         if add_double_eos: # lm1b
-            return ['<S>'] + split_symbols + ['<S>']
+            return ['<S>'] + symbols + ['<S>']
         elif add_eos:
-            return split_symbols + ['<eos>']
+            return symbols + ['<eos>']
         else:
-            return split_symbols
+            return symbols
 
 
 class LMOrderedIterator(object):
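After this rewrite, tokenize() no longer does BERT-style punctuation splitting, accent stripping or character cleanup: it strips the line, optionally lower-cases it, and splits on self.delimiter (an empty string treats the line as a sequence of characters, None gives whitespace splitting via str.split). A small illustration of the new behaviour; the lower_case and delimiter constructor keyword arguments are assumed here, and no pretrained vocabulary is needed just to split:

from pytorch_pretrained_bert import TransfoXLTokenizer

tok = TransfoXLTokenizer(lower_case=True, delimiter=None)

print(tok.tokenize(" Hello , world ! ", add_eos=True))
# ['hello', ',', 'world', '!', '<eos>']

print(tok.tokenize("Hello , world !", add_double_eos=True))
# ['<S>', 'hello', ',', 'world', '!', '<S>']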
@@ -631,42 +581,3 @@ def get_lm_corpus(datadir, dataset):
         torch.save(corpus, fn)
 
     return corpus
-
-def _is_whitespace(char):
-    """Checks whether `chars` is a whitespace character."""
-    # \t, \n, and \r are technically contorl characters but we treat them
-    # as whitespace since they are generally considered as such.
-    if char == " " or char == "\t" or char == "\n" or char == "\r":
-        return True
-    cat = unicodedata.category(char)
-    if cat == "Zs":
-        return True
-    return False
-
-
-def _is_control(char):
-    """Checks whether `chars` is a control character."""
-    # These are technically control characters but we count them as whitespace
-    # characters.
-    if char == "\t" or char == "\n" or char == "\r":
-        return False
-    cat = unicodedata.category(char)
-    if cat.startswith("C"):
-        return True
-    return False
-
-
-def _is_punctuation(char):
-    """Checks whether `chars` is a punctuation character."""
-    cp = ord(char)
-    # We treat all non-letter/number ASCII as punctuation.
-    # Characters such as "^", "$", and "`" are not in the Unicode
-    # Punctuation class but we treat them as punctuation anyways, for
-    # consistency.
-    if ((cp >= 33 and cp <= 47) or (cp >= 58 and cp <= 64) or
-            (cp >= 91 and cp <= 96) or (cp >= 123 and cp <= 126)):
-        return True
-    cat = unicodedata.category(char)
-    if cat.startswith("P"):
-        return True
-    return False