standardizing tokenizers API and adding tests

thomwolf 2019-07-05 11:20:27 +02:00
parent c0239e09e6
commit e75c3f70aa
15 changed files with 150 additions and 107 deletions

View File

@@ -598,3 +598,9 @@ def prune_layer(layer, index, dim=None):
return prune_conv1d_layer(layer, index, dim=1 if dim is None else dim)
else:
raise ValueError("Can't prune layer of class {}".format(layer.__class__))
def clean_up_tokenization(out_string):
    # Clean up common detokenization artifacts: spaces before punctuation and contractions.
    out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
                    ).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
                    ).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
    return out_string
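A minimal usage sketch of the new shared helper, not part of the diff itself; the absolute import path is inferred from the "from .model_utils import clean_up_tokenization" lines added in the tokenizer files below:

# Hypothetical usage of the shared clean-up helper added above.
from pytorch_pretrained_bert.model_utils import clean_up_tokenization

detokenized = "he is n't here , so i 'm leaving ."
print(clean_up_tokenization(detokenized))
# prints: he isn't here, so i'm leaving.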

View File

@@ -48,6 +48,7 @@ PRETRAINED_MODEL_ARCHIVE_MAP = {
'bert-large-cased-whole-word-masking-finetuned-squad': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-cased-whole-word-masking-finetuned-squad-pytorch_model.bin",
'bert-base-cased-finetuned-mrpc': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-finetuned-mrpc-pytorch_model.bin",
}
PRETRAINED_CONFIG_ARCHIVE_MAP = {
'bert-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased-config.json",
'bert-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/bert-large-uncased-config.json",

View File

@@ -26,6 +26,7 @@ from pytorch_pretrained_bert.tokenization_bert import (BasicTokenizer,
_is_control, _is_punctuation,
_is_whitespace, PRETRAINED_VOCAB_ARCHIVE_MAP)
from .tokenization_tests_commons import create_and_check_tokenizer_commons
class TokenizationTest(unittest.TestCase):
@@ -36,28 +37,18 @@ class TokenizationTest(unittest.TestCase):
]
with open("/tmp/bert_tokenizer_test.txt", "w", encoding='utf-8') as vocab_writer:
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
tokenizer = BertTokenizer(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer = tokenizer.from_pretrained(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"UNwant\u00E9d,running")
self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [7, 4, 5, 10, 8, 9])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"

View File

@@ -22,6 +22,7 @@ import pytest
from pytorch_pretrained_bert.tokenization_gpt2 import GPT2Tokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons
class GPT2TokenizationTest(unittest.TestCase):
@@ -39,10 +40,9 @@ class GPT2TokenizationTest(unittest.TestCase):
fp.write("\n".join(merges))
merges_file = fp.name
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
create_and_check_tokenizer_commons(self, GPT2Tokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
tokenizer = GPT2Tokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
@@ -53,17 +53,8 @@ class GPT2TokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer_2 = GPT2Tokenizer.from_pretrained("/tmp/")
os.remove(vocab_file)
os.remove(merges_file)
os.remove(special_tokens_file)
self.assertListEqual(
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
# @pytest.mark.slow
def test_tokenizer_from_pretrained(self):

View File

@@ -22,6 +22,8 @@ import pytest
from pytorch_pretrained_bert.tokenization_openai import OpenAIGPTTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons
class OpenAIGPTTokenizationTest(unittest.TestCase):
@@ -40,6 +42,8 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, OpenAIGPTTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
tokenizer = OpenAIGPTTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
@@ -54,18 +58,6 @@ class OpenAIGPTTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer_2 = OpenAIGPTTokenizer.from_pretrained("/tmp/")
os.remove(vocab_file)
os.remove(merges_file)
os.remove(special_tokens_file)
self.assertListEqual(
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"

View File

@@ -0,0 +1,81 @@
# coding=utf-8
# Copyright 2019 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
from io import open
if sys.version_info[0] == 3:
unicode = str
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
def create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
before_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
vocab_path="/tmp/"
output_files = tokenizer.save_vocabulary(vocab_path=vocab_path)
tokenizer = tokenizer.from_pretrained(vocab_path)
for f in output_files:
os.remove(f)
after_tokens = tokenizer.encode(u"He is very happy, UNwant\u00E9d,running")
tester.assertListEqual(before_tokens, after_tokens)
def create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
text = "Munich and Berlin are nice cities"
filename = u"/tmp/tokenizer.bin"
subwords = tokenizer.tokenize(text)
pickle.dump(tokenizer, open(filename, "wb"))
tokenizer_new = pickle.load(open(filename, "rb"))
subwords_loaded = tokenizer_new.tokenize(text)
tester.assertListEqual(subwords, subwords_loaded)
def create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs):
tokenizer = tokenizer_class(*inputs, **kwargs)
text = u"He is very happy, UNwant\u00E9d,running"
tokens = tokenizer.tokenize(text)
ids = tokenizer.convert_tokens_to_ids(tokens)
ids_2 = tokenizer.encode(text)
tester.assertListEqual(ids, ids_2)
tokens_2 = tokenizer.convert_ids_to_tokens(ids)
text_2 = tokenizer.decode(ids)
tester.assertNotEqual(len(tokens_2), 0)
tester.assertIsInstance(text_2, (str, unicode))
def create_and_check_tokenizer_commons(tester, tokenizer_class, *inputs, **kwargs):
create_and_check_required_methods_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_save_and_load_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
create_and_check_pickle_tokenizer(tester, tokenizer_class, *inputs, **kwargs)
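As a usage sketch (not part of the diff), a tokenizer test can funnel everything through the shared helper. The temporary vocab path below is hypothetical, the vocab mirrors the BERT test above, and the relative import assumes the file sits next to tokenization_tests_commons.py in the tests package:

# Illustrative test using the common checks; paths and class name are made up.
import os
import unittest
from io import open

from pytorch_pretrained_bert.tokenization_bert import BertTokenizer
from .tokenization_tests_commons import create_and_check_tokenizer_commons

class ExampleTokenizationTest(unittest.TestCase):
    def test_commons(self):
        vocab_tokens = ["[UNK]", "[CLS]", "[SEP]", "want", "##want", "##ed",
                        "wa", "un", "runn", "##ing", ","]
        vocab_file = "/tmp/example_bert_vocab.txt"  # hypothetical temp path
        with open(vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(token + "\n" for token in vocab_tokens))
        # One call covers the encode/decode, save/reload and pickle round-trips.
        create_and_check_tokenizer_commons(self, BertTokenizer, vocab_file)
        os.remove(vocab_file)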

View File

@@ -22,6 +22,7 @@ import pytest
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons
class TransfoXLTokenizationTest(unittest.TestCase):
@@ -33,8 +34,9 @@ class TransfoXLTokenizationTest(unittest.TestCase):
vocab_writer.write("".join([x + "\n" for x in vocab_tokens]))
vocab_file = vocab_writer.name
create_and_check_tokenizer_commons(self, TransfoXLTokenizer, vocab_file=vocab_file, lower_case=True)
tokenizer = TransfoXLTokenizer(vocab_file=vocab_file, lower_case=True)
tokenizer.build_vocab()
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
@@ -43,17 +45,6 @@ class TransfoXLTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
vocab_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer = tokenizer.from_pretrained(vocab_file)
os.remove(vocab_file)
tokens = tokenizer.tokenize(u"<unk> UNwanted , running")
self.assertListEqual(tokens, ["<unk>", "unwanted", ",", "running"])
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [0, 4, 8, 7])
def test_full_tokenizer_lower(self):
tokenizer = TransfoXLTokenizer(lower_case=True)

View File

@@ -22,6 +22,7 @@ import pytest
from pytorch_pretrained_bert.tokenization_xlm import XLMTokenizer, PRETRAINED_VOCAB_ARCHIVE_MAP
from .tokenization_tests_commons import create_and_check_tokenizer_commons
class XLMTokenizationTest(unittest.TestCase):
@@ -40,6 +41,8 @@ class XLMTokenizationTest(unittest.TestCase):
fp.write("\n".join(merges))
merges_file = fp.name
create_and_check_tokenizer_commons(self, XLMTokenizer, vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
tokenizer = XLMTokenizer(vocab_file, merges_file, special_tokens=["<unk>", "<pad>"])
os.remove(vocab_file)
os.remove(merges_file)
@@ -54,18 +57,6 @@ class XLMTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)
vocab_file, merges_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path="/tmp/")
tokenizer_2 = XLMTokenizer.from_pretrained("/tmp/")
os.remove(vocab_file)
os.remove(merges_file)
os.remove(special_tokens_file)
self.assertListEqual(
[tokenizer.encoder, tokenizer.decoder, tokenizer.bpe_ranks,
tokenizer.special_tokens, tokenizer.special_tokens_decoder],
[tokenizer_2.encoder, tokenizer_2.decoder, tokenizer_2.bpe_ranks,
tokenizer_2.special_tokens, tokenizer_2.special_tokens_decoder])
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"

View File

@@ -15,28 +15,25 @@
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import sys
import unittest
from io import open
import shutil
import pytest
if sys.version_info[0] == 2:
import cPickle as pickle
else:
import pickle
from pytorch_pretrained_bert.tokenization_xlnet import (XLNetTokenizer,
PRETRAINED_VOCAB_ARCHIVE_MAP,
SPIECE_UNDERLINE)
from .tokenization_tests_commons import create_and_check_tokenizer_commons
SAMPLE_VOCAB = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'fixtures/test_sentencepiece.model')
class XLNetTokenizationTest(unittest.TestCase):
def test_full_tokenizer(self):
tokenizer = XLNetTokenizer(SAMPLE_VOCAB)
create_and_check_tokenizer_commons(self, XLNetTokenizer, SAMPLE_VOCAB)
tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)
tokens = tokenizer.tokenize(u'This is a test')
self.assertListEqual(tokens, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
@@ -44,11 +41,6 @@ class XLNetTokenizationTest(unittest.TestCase):
self.assertListEqual(
tokenizer.convert_tokens_to_ids(tokens), [285, 46, 10, 170, 382])
vocab_path = u"/tmp/"
vocab_file, special_tokens_file = tokenizer.save_vocabulary(vocab_path)
tokenizer = tokenizer.from_pretrained(vocab_path,
keep_accents=True)
tokens = tokenizer.tokenize(u"I was born in 92000, and this is falsé.")
self.assertListEqual(tokens, [SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was', SPIECE_UNDERLINE + u'b',
u'or', u'n', SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
@@ -68,22 +60,6 @@ class XLNetTokenizationTest(unittest.TestCase):
SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f', u'al', u's',
u'<unk>', u'.'])
text = "Munich and Berlin are nice cities"
filename = u"/tmp/tokenizer.bin"
subwords = tokenizer.tokenize(text)
pickle.dump(tokenizer, open(filename, "wb"))
tokenizer_new = pickle.load(open(filename, "rb"))
subwords_loaded = tokenizer_new.tokenize(text)
self.assertListEqual(subwords, subwords_loaded)
os.remove(filename)
os.remove(vocab_file)
os.remove(special_tokens_file)
@pytest.mark.slow
def test_tokenizer_from_pretrained(self):
cache_dir = "/tmp/pytorch_pretrained_bert_test/"

View File

@@ -23,6 +23,7 @@ import unicodedata
from io import open
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
logger = logging.getLogger(__name__)
@@ -185,6 +186,19 @@ class BertTokenizer(object):
tokens.append(self.ids_to_tokens[i])
return tokens
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, token_ids, clean_up_tokenization_spaces=True):
"""Converts a sequence of ids in a string."""
tokens = self.convert_ids_to_tokens(token_ids)
out_string = ''.join(tokens).replace(' ##', '').strip()
if clean_up_tokenization_spaces:
for special_tok in (self.UNK_TOKEN, self.SEP_TOKEN, self.PAD_TOKEN, self.CLS_TOKEN, self.MASK_TOKEN):
out_string = out_string.replace(special_tok, '')
out_string = clean_up_tokenization(out_string)
return out_string
def save_vocabulary(self, vocab_path):
"""Save the tokenizer vocabulary to a directory or file."""
index = 0
@@ -198,7 +212,7 @@ class BertTokenizer(object):
index = token_index
writer.write(token + u'\n')
index += 1
return vocab_file
return (vocab_file,)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, cache_dir=None, *inputs, **kwargs):
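Put together, the standardized BertTokenizer surface looks roughly like the sketch below (not part of the diff; the vocab file path is hypothetical and must point at an existing WordPiece vocab):

# Illustrative round trip through the new encode/decode/save_vocabulary API.
from pytorch_pretrained_bert.tokenization_bert import BertTokenizer

tokenizer = BertTokenizer("/tmp/vocab.txt")           # hypothetical WordPiece vocab file
ids = tokenizer.encode("He is very happy")            # tokenize() + convert_tokens_to_ids()
text = tokenizer.decode(ids)                          # ids back to a cleaned-up string
(vocab_file,) = tokenizer.save_vocabulary("/tmp/")    # now returns a tuple of saved file paths
reloaded = BertTokenizer.from_pretrained("/tmp/")     # reload from the saved directory
assert reloaded.encode("He is very happy") == ids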

View File

@@ -23,6 +23,8 @@ import os
import regex as re
from io import open
from .model_utils import clean_up_tokenization
try:
from functools import lru_cache
except ImportError:
@@ -275,9 +277,7 @@ class GPT2Tokenizer(object):
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
if clean_up_tokenization_spaces:
text = text.replace('<unk>', '')
text = text.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
text = clean_up_tokenization(text)
return text
def save_vocabulary(self, vocab_path):

View File

@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
@@ -277,9 +278,7 @@ class OpenAIGPTTokenizer(object):
out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
out_string = clean_up_tokenization(out_string)
return out_string
def save_vocabulary(self, vocab_path):

View File

@@ -31,6 +31,7 @@ import torch
import numpy as np
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
if sys.version_info[0] == 2:
import cPickle as pickle
@@ -109,6 +110,9 @@ class TransfoXLTokenizer(object):
self.vocab_file = vocab_file
self.never_split = never_split
if vocab_file is not None:
self.build_vocab()
def count_file(self, path, verbose=False, add_eos=False):
if verbose: print('counting file {} ...'.format(path))
assert os.path.exists(path)
@@ -155,7 +159,7 @@ class TransfoXLTokenizer(object):
if os.path.isdir(vocab_path):
vocab_file = os.path.join(vocab_path, VOCAB_NAME)
torch.save(self.__dict__, vocab_file)
return vocab_file
return (vocab_file,)
def build_vocab(self):
if self.vocab_file:
@@ -251,12 +255,20 @@ class TransfoXLTokenizer(object):
def convert_to_tensor(self, symbols):
return torch.LongTensor(self.convert_tokens_to_ids(symbols))
def decode(self, indices, exclude=None):
def encode(self, text):
return self.convert_tokens_to_ids(self.tokenize(text))
def decode(self, indices, exclude=None, clean_up_tokenization_spaces=True):
"""Converts a sequence of indices in a string."""
if exclude is None:
return ' '.join([self.get_sym(idx) for idx in indices])
out_string = ' '.join([self.get_sym(idx) for idx in indices])
else:
return ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
out_string = ' '.join([self.get_sym(idx) for idx in indices if idx not in exclude])
if clean_up_tokenization_spaces:
out_string = clean_up_tokenization(out_string)
return out_string
def __len__(self):
return len(self.idx2sym)
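The TransfoXL tokenizer now exposes the same encode/decode pair; a sketch (not part of the diff) with a hypothetical whitespace-delimited vocab file containing the tokens used below:

# Illustrative use of the standardized TransfoXL API added above.
from pytorch_pretrained_bert.tokenization_transfo_xl import TransfoXLTokenizer

# build_vocab() now runs automatically because vocab_file is not None (see the __init__ hunk above).
tokenizer = TransfoXLTokenizer(vocab_file="/tmp/transfo_xl_vocab.txt", lower_case=True)

ids = tokenizer.encode("<unk> unwanted , running")                 # tokenize() + convert_tokens_to_ids()
print(tokenizer.decode(ids))                                       # "<unk> unwanted, running" after clean-up
print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))   # "<unk> unwanted , running" verbatim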

View File

@@ -26,6 +26,7 @@ from io import open
from tqdm import tqdm
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
from .tokenization_bert import BasicTokenizer
logger = logging.getLogger(__name__)
@@ -285,9 +286,7 @@ class XLMTokenizer(object):
out_string = ''.join(tokens).replace('</w>', ' ').strip()
if clean_up_tokenization_spaces:
out_string = out_string.replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
out_string = clean_up_tokenization(out_string)
return out_string
def save_vocabulary(self, vocab_path):

View File

@@ -27,6 +27,7 @@ import unicodedata
import six
from .file_utils import cached_path
from .model_utils import clean_up_tokenization
logger = logging.getLogger(__name__)
@@ -316,9 +317,7 @@ class XLNetTokenizer(object):
out_string = ''.join(tokens)
if clean_up_tokenization_spaces:
out_string = out_string.strip().replace('<unk>', '')
out_string = out_string.replace(' .', '.').replace(' ?', '?').replace(' !', '!').replace(' ,', ','
).replace(" ' ", "'").replace(" n't", "n't").replace(" 'm", "'m").replace(" do not", " don't"
).replace(" 's", "'s").replace(" 've", "'ve").replace(" 're", "'re")
out_string = clean_up_tokenization(out_string)
return out_string
def save_vocabulary(self, vocab_path):