fixing tests

This commit is contained in:
thomwolf 2019-04-15 12:55:38 +02:00
parent 870b734bfd
commit e8568a3b17
4 changed files with 51 additions and 14 deletions

View File

@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
@lru_cache()
def bytes_to_unicode():
@ -97,6 +98,11 @@ class GPT2Tokenizer(object):
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@ -125,7 +131,11 @@ class GPT2Tokenizer(object):
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
@ -194,7 +204,11 @@ class GPT2Tokenizer(object):
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
@ -203,9 +217,14 @@ class GPT2Tokenizer(object):
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(bpe_tokens + u'\n')
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
writer.write(token + u'\n')
return vocab_file, merge_file, special_tokens_file
def encode(self, text):
bpe_tokens = []

View File

@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
}
VOCAB_NAME = 'vocab.json'
MERGES_NAME = 'merges.txt'
SPECIAL_TOKENS_NAME = 'special_tokens.txt'
def get_pairs(word):
"""
@ -89,6 +90,11 @@ class OpenAIGPTTokenizer(object):
else:
vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
if not os.path.exists(special_tokens_file):
special_tokens_file = None
else:
logger.info("loading special tokens file {}".format(special_tokens_file))
# redirect to the cache, if necessary
try:
resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@ -117,7 +123,11 @@ class OpenAIGPTTokenizer(object):
max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
# Instantiate tokenizer.
tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
if special_tokens_file and 'special_tokens' not in kwargs:
special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
else:
special_tokens = kwargs.pop('special_tokens', [])
tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
return tokenizer
def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
@ -269,7 +279,11 @@ class OpenAIGPTTokenizer(object):
return
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
merge_file = os.path.join(vocab_path, MERGES_NAME)
json.dump(self.encoder, vocab_file)
special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
@ -278,6 +292,11 @@ class OpenAIGPTTokenizer(object):
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(bpe_tokens + u'\n')
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_file, merge_file
with open(special_tokens_file, 'w', encoding='utf-8') as writer:
for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
writer.write(token + u'\n')
return vocab_file, merge_file, special_tokens_file

View File

@ -52,7 +52,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer.from_pretrained("/tmp/")
os.remove(vocab_file)
os.remove(merges_file)

View File

@ -35,7 +35,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer.build_vocab()
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
self.assertListEqual(
@ -45,7 +45,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer.from_pretrained(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
self.assertListEqual(
@ -56,15 +56,14 @@ class TransfoXLTokenizationTest(unittest.TestCase):
tokenizer = TransfoXLTokenizer(lower_case=True)
self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
["hello", "!", "how", "are", "you", "?"])
self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
def test_full_tokenizer_no_lower(self):
tokenizer = TransfoXLTokenizer(lower_case=False)
self.assertListEqual(
tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
["HeLLo", "!", "how", "Are", "yoU", "?"])