Roberta tokenization + fixed tests (py3 + py2).

This commit is contained in:
LysandreJik 2019-08-09 15:02:13 -04:00
parent 14e970c271
commit 75d5f98fd2
3 changed files with 138 additions and 224 deletions

View File

@ -157,42 +157,6 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
return config, inputs_dict
def test_inference_masked_lm(self):
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 50265))
self.assertEqual(
output.shape,
expected_shape
)
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[33.8843, -4.3107, 22.7779],
[4.6533, -2.8099, 13.6252],
[1.8222, -3.6898, 8.8600]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
# @pytest.mark.slow
def test_inference_no_head(self):
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[-0.0231, 0.0782, 0.0074],
[-0.1854, 0.0539, -0.0174],
[0.0548, 0.0799, 0.1687]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
def setUp(self):
self.model_tester = RobertaModelTest.RobertaModelTester(self)
self.config_tester = ConfigTester(self, config_class=RobertaConfig, hidden_size=37)
@ -220,7 +184,7 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
class RobertaModelIntegrationTest(unittest.TestCase):
# @pytest.mark.slow
@pytest.mark.slow
def test_inference_masked_lm(self):
model = RobertaForMaskedLM.from_pretrained('roberta-base')
@ -241,7 +205,7 @@ class RobertaModelIntegrationTest(unittest.TestCase):
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
# @pytest.mark.slow
@pytest.mark.slow
def test_inference_no_head(self):
model = RobertaModel.from_pretrained('roberta-base')

View File

@ -18,8 +18,7 @@ import os
import json
import unittest
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, DICT_FILES_NAMES
from pytorch_transformers.tokenization_gpt2 import GPT2Tokenizer, VOCAB_FILES_NAMES
from pytorch_transformers.tokenization_roberta import RobertaTokenizer, VOCAB_FILES_NAMES
from .tokenization_tests_commons import CommonTestCases
@ -45,8 +44,7 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
fp.write("\n".join(merges))
def get_tokenizer(self):
bpe_tokenizer = GPT2Tokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
return RobertaTokenizer.from_pretrained("roberta-base", bpe_tokenizer=bpe_tokenizer)
return RobertaTokenizer.from_pretrained(self.tmpdirname, **self.special_tokens_map)
def get_input_output_texts(self):
input_text = u"lower newer"
@ -54,15 +52,14 @@ class RobertaTokenizationTest(CommonTestCases.CommonTokenizerTester):
return input_text, output_text
def test_full_tokenizer(self):
tokenizer = self.get_tokenizer()
tokenizer = RobertaTokenizer(self.vocab_file, self.merges_file, **self.special_tokens_map)
text = "lower"
bpe_tokens = ["low", "er"]
tokens = tokenizer.tokenize(text)
self.assertListEqual(tokens, bpe_tokens)
input_tokens = tokens + [tokenizer.unk_token]
input_bpe_tokens = [0, 4, 12, 176, 2]
tokenizer.convert_tokens_to_ids(input_tokens)
input_bpe_tokens = [13, 12, 17]
self.assertListEqual(
tokenizer.convert_tokens_to_ids(input_tokens), input_bpe_tokens)

View File

@ -12,229 +12,182 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
"""Tokenization classes for OpenAI GPT."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import sys
import json
import logging
import re
from io import open
import six
import os
import regex as re
from io import open
from .tokenization_gpt2 import bytes_to_unicode, get_pairs
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
try:
from functools import lru_cache
except ImportError:
# Just a dummy decorator to get the checks to run on python2
# because honestly I don't want to support a byte-level unicode BPE tokenizer on python 2 right now.
def lru_cache():
return lambda func: func
logger = logging.getLogger(__name__)
DICT_FILES_NAMES = {
'dict_file': 'dict.txt',
VOCAB_FILES_NAMES = {
'vocab_file': 'vocab.json',
'merges_file': 'merges.txt',
}
PRETRAINED_DICT_FILES_MAP = {
'dict_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
},
PRETRAINED_VOCAB_FILES_MAP = {
'vocab_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-vocab.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-vocab.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-vocab.json",
},
'merges_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-merges.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-merges.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-merges.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'roberta-base': 512,
'roberta-large': 512,
'roberta-large-mnli': 512,
'roberta-base': 1024,
'roberta-large': 1024,
'roberta-large-mnli': 1024,
}
SPACE_NORMALIZER = re.compile(r"\s+")
def tokenize_line(line):
line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip()
return line.split()
class Dictionary(object):
"""
A mapping from symbols to consecutive integers
From Facebook's fairseq.
"""
def __init__(
self,
pad='<pad>',
eos='</s>',
unk='<unk>',
bos='<s>',
extra_special_symbols=None,
):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def index(self, sym):
"""Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
@classmethod
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f, ignore_utf_errors)
return d
def add_from_file(self, f, ignore_utf_errors=False):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, six.string_types):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
self.add_from_file(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
return
lines = f.read().splitlines()
for line in lines:
idx = line.rfind(' ')
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
count = int(line[idx + 1:])
self.indices[word] = len(self.symbols)
self.symbols.append(word)
self.count.append(count)
def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = [0] * (nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
class RobertaTokenizer(PreTrainedTokenizer):
"""
RoBERTa tokenizer. Peculiarities:
- GPT-2 tokenizer with a different integer mapping on top.
GPT-2 BPE tokenizer. Peculiarities:
- Byte-level BPE
"""
vocab_files_names = DICT_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_DICT_FILES_MAP
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, dict_file, bpe_tokenizer=None, bos_token="<s>", eos_token="</s>", sep_token="</s>", cls_token="<s>",
unk_token="<unk>", **kwargs):
super(RobertaTokenizer, self).__init__(cls_token=bos_token, sep_token=eos_token, eos_token=eos_token,
unk_token=unk_token, **kwargs)
def __init__(self, vocab_file, merges_file, errors='replace', bos_token="<s>", eos_token="</s>", sep_token="</s>",
cls_token="<s>", unk_token="<unk>", **kwargs):
super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, unk_token=unk_token,
sep_token=sep_token, cls_token=cls_token, **kwargs)
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2") if bpe_tokenizer is None else bpe_tokenizer
self.dictionary = Dictionary.load(dict_file)
self.encoder = json.load(open(vocab_file, encoding="utf-8"))
self.decoder = {v: k for k, v in self.encoder.items()}
self.errors = errors # how to handle errors in decoding
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
bpe_data = open(merges_file, encoding='utf-8').read().split('\n')[1:-1]
bpe_merges = [tuple(merge.split()) for merge in bpe_data]
self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges))))
self.cache = {}
# Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions
self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
@property
def vocab_size(self):
return len(self.dictionary.indices)
return len(self.encoder)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token)
pairs = get_pairs(word)
if not pairs:
return token
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def _tokenize(self, text):
""" Use GPT-2 Tokenizer """
return self.gpt2_tokenizer._tokenize(text)
""" Tokenize a string. """
bpe_tokens = []
for token in re.findall(self.pat, text):
if sys.version_info[0] == 2:
token = ''.join(self.byte_encoder[ord(b)] for b in token)
else:
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(bpe_token for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def _convert_token_to_id(self, token):
if self.dictionary.index(token) != 3:
return self.dictionary.index(token)
return self.dictionary.index(str(self.gpt2_tokenizer.convert_tokens_to_ids(token)))
""" Converts a token (str/unicode) in an id using the vocab. """
return self.encoder.get(token, self.encoder.get(self.unk_token))
def _convert_id_to_token(self, index):
symbol = self.dictionary[index]
try:
idx = int(symbol)
return self.gpt2_tokenizer._convert_id_to_token(idx)
except ValueError:
return symbol
"""Converts an index (integer) in a token (string/unicode) using the vocab."""
return self.decoder.get(index)
def convert_tokens_to_string(self, tokens):
return self.gpt2_tokenizer.convert_tokens_to_string(tokens)
""" Converts a sequence of tokens (string) in a single string. """
text = ''.join(tokens)
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors=self.errors)
return text
def convert_tokens_to_ids(self, tokens, no_sep_cls_tokens=False):
cls = [self._convert_token_to_id(self.cls_token)]
tokens = super().convert_tokens_to_ids(tokens)
def add_special_tokens_single_sentence(self, token_ids):
return [self._convert_token_to_id(self.cls_token)] + token_ids + [self._convert_token_to_id(self.sep_token)]
def add_special_tokens_sentences_pair(self, *token_ids):
sep = [self._convert_token_to_id(self.sep_token)]
return (cls + tokens + sep) if (isinstance(tokens, list) and not no_sep_cls_tokens) else tokens
def convert_ids_to_tokens(self, ids, skip_special_tokens=False):
return super().convert_ids_to_tokens(ids, skip_special_tokens=skip_special_tokens)[1:-1]
cls = [self._convert_token_to_id(self.cls_token)]
return cls + token_ids[0] + sep + sep + token_ids[1] + sep
def save_vocabulary(self, save_directory):
"""Save the tokenizer vocabulary and merge files to a directory."""
if not os.path.isdir(save_directory):
logger.error("Vocabulary path ({}) should be a directory".format(save_directory))
return
dict_file = os.path.join(save_directory, DICT_FILES_NAMES['dict_file'])
vocab_file = os.path.join(save_directory, VOCAB_FILES_NAMES['vocab_file'])
merge_file = os.path.join(save_directory, VOCAB_FILES_NAMES['merges_file'])
with open(dict_file, 'w', encoding='utf-8') as f:
for i in range(self.dictionary.nspecial, len(self.dictionary.count)):
f.write(f"{list(self.dictionary.indices.keys())[i]} {self.dictionary.count[i]}\n")
with open(vocab_file, 'w', encoding='utf-8') as f:
f.write(json.dumps(self.encoder, ensure_ascii=False))
vocab_files = self.gpt2_tokenizer.save_pretrained(save_directory)
index = 0
with open(merge_file, "w", encoding="utf-8") as writer:
writer.write(u'#version: 0.2\n')
for bpe_tokens, token_index in sorted(self.bpe_ranks.items(), key=lambda kv: kv[1]):
if index != token_index:
logger.warning("Saving vocabulary to {}: BPE merge indices are not consecutive."
" Please check that the tokenizer is not corrupted!".format(merge_file))
index = token_index
writer.write(' '.join(bpe_tokens) + u'\n')
index += 1
return vocab_files + (dict_file,)
return vocab_file, merge_file