switching to moses tokenizer

thomwolf 2019-10-10 10:11:16 +02:00
parent 036483fae5
commit 43a237f15e


@@ -22,8 +22,9 @@ import os
import regex as re
from io import open
from .tokenization_bert import BasicTokenizer
import sacremoses as sm
from .tokenization_xlm import replace_unicode_punct, remove_non_printing_char
from .tokenization_utils import PreTrainedTokenizer
logger = logging.getLogger(__name__)
@@ -48,39 +49,11 @@ PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'ctrl': 256,
}
def text_standardize(text):
"""
fixes some issues the spacy tokenizer had on books corpus
also does some whitespace standardization
"""
text = text.replace('—', '-')
text = text.replace('–', '-')
text = text.replace('―', '-')
text = text.replace('…', '...')
text = text.replace('´', "'")
text = re.sub(r'''(-+|~+|!+|"+|;+|\?+|\++|,+|\)+|\(+|\\+|\/+|\*+|\[+|\]+|}+|{+|\|+|_+)''', r' \1 ', text)
text = re.sub(r'\s*\n\s*', ' \n ', text)
text = re.sub(r'[^\S\n]+', ' ', text)
return text.strip()
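For reference, a minimal sketch (not part of the commit; the sample string is illustrative) of what the removed text_standardize helper does:

# Illustrative only: typographic dashes/ellipses become ASCII, the listed
# punctuation runs get padded with spaces, and whitespace other than newlines
# collapses to single spaces.
sample = 'He said\u2014wait\u2026 ok!\nnext   line'
print(text_standardize(sample))
# roughly: 'He said - wait... ok ! \n next line'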
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
# pairs = []
# prev_char = word[0]
# for i, char in enumerate(word[1:]):
# #_i = i + 1
# #if word[_i+1:] == tuple('</w>'):
# # pairs.append((prev_char, char+'</w>'))
# # break
# #else:
# if True:
# pairs.append((prev_char, char))
# prev_char = char
pairs = set()
prev_char = word[0]
for char in word[1:]:
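As a quick illustration (not part of the diff), get_pairs collects the adjacent symbol pairs that BPE merge ranking operates on:

# Illustrative only: a word is represented as a tuple of variable-length symbols.
word = ('h', 'e', 'l', 'l', 'o')
# get_pairs(word) -> {('h', 'e'), ('e', 'l'), ('l', 'l'), ('l', 'o')}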
@@ -108,6 +81,9 @@ class CTRLTokenizer(PreTrainedTokenizer):
self.max_len_single_sentence = self.max_len # no default special tokens - you can update this value if you add special tokens
self.max_len_sentences_pair = self.max_len # no default special tokens - you can update this value if you add special tokens
self.punct_normalizer = sm.MosesPunctNormalizer(lang='en')
self.moses_tokenizer = sm.MosesTokenizer(lang='en')
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v:k for k,v in self.encoder.items()}
merges = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
@@ -162,11 +138,23 @@ class CTRLTokenizer(PreTrainedTokenizer):
self.cache[token] = word
return word
def _tokenize(self, text):
def moses_pipeline(self, text):
text = replace_unicode_punct(text)
text = self.punct_normalizer.normalize(text)
text = remove_non_printing_char(text)
return text
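For context (not part of the commit), the sacremoses objects used above can be exercised on their own; the sample string is illustrative:

import sacremoses as sm

# Same calls as moses_pipeline / _tokenize above.
normalizer = sm.MosesPunctNormalizer(lang='en')
moses_tok = sm.MosesTokenizer(lang='en')

text = normalizer.normalize('Hello   «world» !')      # normalizes punctuation and spacing
tokens = moses_tok.tokenize(text, return_str=False, escape=False)
# tokens is a list of surface tokens with punctuation split off, left unescaped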
def _tokenize(self, text, bypass_tokenizer=False):
""" Tokenize a string.
"""
split_tokens = []
text = text.split(' ')
if bypass_tokenizer:
text = text.split()
else:
text = self.moses_pipeline(text)
text = self.moses_tokenizer.tokenize(text, return_str=False, escape=False)
for token in text:
split_tokens.extend([t for t in self.bpe(token).split(' ')])
return split_tokens
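A sketch of how the two tokenization paths behave (not part of the commit; tokenizer is assumed to be a CTRLTokenizer built from real vocab and merges files):

# Default path: Moses punctuation normalization + Moses tokenization, then BPE per token.
tokens = tokenizer._tokenize("Hello world, this is a test.")

# Bypass path: the input is treated as already tokenized and is only split on
# whitespace before BPE, mirroring the old space-splitting behaviour.
pretok = tokenizer._tokenize("Hello world , this is a test .", bypass_tokenizer=True)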