mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
fixing tests
parent 870b734bfd
commit e8568a3b17
@@ -45,6 +45,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'
 
 @lru_cache()
 def bytes_to_unicode():
@@ -97,6 +98,11 @@ class GPT2Tokenizer(object):
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
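For context, the new lookup treats special_tokens.txt as optional: a missing file silently means "no special tokens", while a present one is logged and loaded. A minimal sketch of preparing such a directory (the path and tokens here are hypothetical, not from the commit):

    import os

    model_dir = "/tmp/gpt2-local"  # hypothetical local checkpoint directory
    os.makedirs(model_dir, exist_ok=True)
    # One token per line; the loader keeps their order.
    with open(os.path.join(model_dir, "special_tokens.txt"), "w", encoding="utf-8") as f:
        f.write("<bos>\n<eos>\n")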
@@ -125,7 +131,11 @@ class GPT2Tokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer
 
     def __init__(self, vocab_file, merges_file, errors='replace', max_len=None):
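Note the precedence this establishes: an explicit special_tokens keyword wins over the file, and the file is read with .read().split('\n')[:-1], so it must end with a newline after the last token. A hedged usage sketch (directory hypothetical):

    # Picks up /tmp/gpt2-local/special_tokens.txt if present:
    tok = GPT2Tokenizer.from_pretrained("/tmp/gpt2-local")
    # An explicit list takes priority; the file is ignored:
    tok = GPT2Tokenizer.from_pretrained("/tmp/gpt2-local", special_tokens=["<bos>", "<eos>"])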
@@ -194,7 +204,11 @@ class GPT2Tokenizer(object):
             return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
-        json.dump(self.encoder, vocab_file)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
             writer.write(u'#version: 0.2\n')
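The dropped json.dump(self.encoder, vocab_file) call was a real bug: json.dump expects a writable file object as its second argument, and vocab_file here is a path string. A small sketch of the corrected pattern with a toy vocabulary (not from the commit):

    import json

    encoder = {"hello": 0, "world": 1}  # toy stand-in for self.encoder
    with open("/tmp/vocab.json", "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps non-ASCII byte-level tokens human-readable
        f.write(json.dumps(encoder, ensure_ascii=False))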
@@ -203,9 +217,14 @@ class GPT2Tokenizer(object):
                 logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                " Please check that the tokenizer is not corrupted!".format(merge_file))
                 index = token_index
-            writer.write(bpe_tokens + u'\n')
+            writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1
-        return vocab_file, merge_file
+
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
+                writer.write(token + u'\n')
+
+        return vocab_file, merge_file, special_tokens_file
 
     def encode(self, text):
         bpe_tokens = []
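The writer.write change fixes another type error: in this loop, bpe_tokens is a key of self.bpe_ranks, i.e. a pair of strings, so concatenating it with u'\n' raises TypeError. Joining the pair also produces the one-merge-per-line merges.txt format. A toy illustration (pair chosen for the example):

    bpe_pair = ("h", "e")                # stand-in for a bpe_ranks key
    line = " ".join(bpe_pair) + "\n"     # -> "h e\n", one merge rule per line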
@@ -41,6 +41,7 @@ PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP = {
 }
 VOCAB_NAME = 'vocab.json'
 MERGES_NAME = 'merges.txt'
+SPECIAL_TOKENS_NAME = 'special_tokens.txt'
 
 def get_pairs(word):
     """
@@ -89,6 +90,11 @@ class OpenAIGPTTokenizer(object):
         else:
             vocab_file = os.path.join(pretrained_model_name_or_path, VOCAB_NAME)
             merges_file = os.path.join(pretrained_model_name_or_path, MERGES_NAME)
+            special_tokens_file = os.path.join(pretrained_model_name_or_path, SPECIAL_TOKENS_NAME)
+            if not os.path.exists(special_tokens_file):
+                special_tokens_file = None
+            else:
+                logger.info("loading special tokens file {}".format(special_tokens_file))
         # redirect to the cache, if necessary
         try:
             resolved_vocab_file = cached_path(vocab_file, cache_dir=cache_dir)
@@ -117,7 +123,11 @@ class OpenAIGPTTokenizer(object):
             max_len = PRETRAINED_VOCAB_POSITIONAL_EMBEDDINGS_SIZE_MAP[pretrained_model_name_or_path]
             kwargs['max_len'] = min(kwargs.get('max_len', int(1e12)), max_len)
         # Instantiate tokenizer.
-        tokenizer = cls(resolved_vocab_file, resolved_merges_file, *inputs, **kwargs)
+        if special_tokens_file and 'special_tokens' not in kwargs:
+            special_tokens = open(special_tokens_file, encoding='utf-8').read().split('\n')[:-1]
+        else:
+            special_tokens = kwargs.pop('special_tokens', [])
+        tokenizer = cls(resolved_vocab_file, resolved_merges_file, special_tokens=special_tokens, *inputs, **kwargs)
         return tokenizer
 
     def __init__(self, vocab_file, merges_file, special_tokens=None, max_len=None):
@@ -269,7 +279,11 @@ class OpenAIGPTTokenizer(object):
             return
         vocab_file = os.path.join(vocab_path, VOCAB_NAME)
         merge_file = os.path.join(vocab_path, MERGES_NAME)
-        json.dump(self.encoder, vocab_file)
+        special_tokens_file = os.path.join(vocab_path, SPECIAL_TOKENS_NAME)
+
+        with open(vocab_file, 'w', encoding='utf-8') as f:
+            f.write(json.dumps(self.encoder, ensure_ascii=False))
+
         index = 0
         with open(merge_file, "w", encoding="utf-8") as writer:
             writer.write(u'#version: 0.2\n')
@@ -278,6 +292,11 @@ class OpenAIGPTTokenizer(object):
                 logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
                                " Please check that the tokenizer is not corrupted!".format(merge_file))
                 index = token_index
-            writer.write(bpe_tokens + u'\n')
+            writer.write(' '.join(bpe_tokens) + u'\n')
             index += 1
-        return vocab_file, merge_file
+
+        with open(special_tokens_file, 'w', encoding='utf-8') as writer:
+            for token in sorted(self.special_tokens.keys(), key=lambda kv: kv[1]):
+                writer.write(token + u'\n')
+
+        return vocab_file, merge_file, special_tokens_file
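The OpenAIGPTTokenizer hunks mirror the GPT2Tokenizer ones line for line; together they change save_vocabulary to return three paths and from_pretrained to pick the third file up again. A hedged round-trip sketch (the /tmp/ path matches the tests below):

    vocab_file, merge_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
    reloaded = OpenAIGPTTokenizer.from_pretrained("/tmp/")  # reads special_tokens.txt automatically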
@@ -52,7 +52,7 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
         self.assertListEqual(
             tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
 
-        vocab_file, merges_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
+        vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
         tokenizer.from_pretrained("/tmp/")
         os.remove(vocab_file)
         os.remove(merges_file)
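Since save_vocabulary now writes special_tokens.txt unconditionally, the test could also clean up the third path it receives (a suggestion, not part of the commit):

    os.remove(special_tokens_file)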
@@ -35,7 +35,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.build_vocab()
         os.remove(vocab_file)
 
-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
 
         self.assertListEqual(
@@ -45,7 +45,7 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer.from_pretrained(vocab_file)
         os.remove(vocab_file)
 
-        tokens = tokenizer.tokenize(u"<unk> UNwant\u00E9d,running")
+        tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
         self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
 
         self.assertListEqual(
@@ -56,15 +56,14 @@ class TransfoXLTokenizationTest(unittest.TestCase):
         tokenizer = TransfoXLTokenizer(lower_case=True)
 
         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["hello", "!", "how", "are", "you", "?"])
-        self.assertListEqual(tokenizer.tokenize(u"H\u00E9llo"), ["hello"])
 
     def test_full_tokenizer_no_lower(self):
         tokenizer = TransfoXLTokenizer(lower_case=False)
 
         self.assertListEqual(
-            tokenizer.tokenize(u" \tHeLLo!how \n Are yoU? "),
+            tokenizer.tokenize(u" \tHeLLo ! how \n Are yoU ? "),
             ["HeLLo", "!", "how", "Are", "yoU", "?"])
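A plausible reading of these fixture changes: TransfoXLTokenizer splits on whitespace rather than on punctuation, so a fused form like u"UNwant\u00E9d,running" would survive as a single token; pre-separating the punctuation matches what the tokenizer can actually produce. In plain Python, assuming whitespace splitting:

    tokens = u"<unk> UNwanted , running".split()
    # -> ['<unk>', 'UNwanted', ',', 'running']; lower-casing yields the expected fixture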