[RoBERTa] model conversion, inference, tests 🔥

This commit is contained in:
Julien Chaumond 2019-08-04 21:39:21 -04:00
parent 44dd941efb
commit 05c083520a
6 changed files with 622 additions and 0 deletions

View File

@ -12,6 +12,7 @@ The library currently contains PyTorch implementations, pre-trained model weight
4. **[Transformer-XL](https://github.com/kimiyoung/transformer-xl)** (from Google/CMU) released with the paper [Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context](https://arxiv.org/abs/1901.02860) by Zihang Dai*, Zhilin Yang*, Yiming Yang, Jaime Carbonell, Quoc V. Le, Ruslan Salakhutdinov.
5. **[XLNet](https://github.com/zihangdai/xlnet/)** (from Google/CMU) released with the paper [XLNet: Generalized Autoregressive Pretraining for Language Understanding](https://arxiv.org/abs/1906.08237) by Zhilin Yang*, Zihang Dai*, Yiming Yang, Jaime Carbonell, Ruslan Salakhutdinov, Quoc V. Le.
6. **[XLM](https://github.com/facebookresearch/XLM/)** (from Facebook) released together with the paper [Cross-lingual Language Model Pretraining](https://arxiv.org/abs/1901.07291) by Guillaume Lample and Alexis Conneau.
7. **[RoBERTa](https://github.com/pytorch/fairseq/tree/master/examples/roberta)** (from Facebook), a [Robustly Optimized BERT Pretraining Approach](https://arxiv.org/abs/1907.11692) by Yinhan Liu, Myle Ott et al.
These implementations have been tested on several datasets (see the example scripts) and should match the performances of the original implementations (e.g. ~93 F1 on SQuAD for BERT Whole-Word-Masking, ~88 F1 on RocStories for OpenAI GPT, ~18.3 perplexity on WikiText 103 for Transformer-XL, ~0.916 Peason R coefficient on STS-B for XLNet). You can find more details on the performances in the Examples section of the [documentation](https://huggingface.co/pytorch-transformers/examples.html).

View File

@ -0,0 +1,164 @@
# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert RoBERTa checkpoint."""
from __future__ import absolute_import, division, print_function
import argparse
import logging
import numpy as np
import torch
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
BertIntermediate, BertLayer,
BertModel, BertOutput,
BertSelfAttention,
BertSelfOutput)
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
RobertaForMaskedLM,
RobertaModel)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path):
"""
Copy/paste/tweak roberta's weights to our BERT structure.
"""
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout
config = BertConfig(
vocab_size_or_config_json_file=50265,
hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers,
num_attention_heads=roberta.args.encoder_attention_heads,
intermediate_size=roberta.args.encoder_ffn_embed_dim,
max_position_embeddings=514,
type_vocab_size=1,
)
print("Our BERT config:", config)
model = RobertaForMaskedLM(config)
model.eval()
# Now let's copy all the weights.
# Embeddings
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
model.roberta.embeddings.LayerNorm.variance_epsilon = roberta_sent_encoder.emb_layer_norm.eps
for i in range(config.num_hidden_layers):
# Encoder: start of layer
layer: BertLayer = model.roberta.encoder.layer[i]
roberta_layer: TransformerSentenceEncoderLayer = roberta_sent_encoder.layers[i]
### self attention
self_attn: BertSelfAttention = layer.attention.self
assert(
roberta_layer.self_attn.in_proj_weight.shape == torch.Size((3 * config.hidden_size, config.hidden_size))
)
# we use three distinct linear layers so we split the source layer here.
self_attn.query.weight.data = roberta_layer.self_attn.in_proj_weight[:config.hidden_size, :]
self_attn.query.bias.data = roberta_layer.self_attn.in_proj_bias[:config.hidden_size]
self_attn.key.weight.data = roberta_layer.self_attn.in_proj_weight[config.hidden_size:2*config.hidden_size, :]
self_attn.key.bias.data = roberta_layer.self_attn.in_proj_bias[config.hidden_size:2*config.hidden_size]
self_attn.value.weight.data = roberta_layer.self_attn.in_proj_weight[2*config.hidden_size:, :]
self_attn.value.bias.data = roberta_layer.self_attn.in_proj_bias[2*config.hidden_size:]
### self-attention output
self_output: BertSelfOutput = layer.attention.output
assert(
self_output.dense.weight.shape == roberta_layer.self_attn.out_proj.weight.shape
)
self_output.dense.weight = roberta_layer.self_attn.out_proj.weight
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
self_output.LayerNorm.variance_epsilon = roberta_layer.self_attn_layer_norm.eps
### intermediate
intermediate: BertIntermediate = layer.intermediate
assert(
intermediate.dense.weight.shape == roberta_layer.fc1.weight.shape
)
intermediate.dense.weight = roberta_layer.fc1.weight
intermediate.dense.bias = roberta_layer.fc1.bias
### output
bert_output: BertOutput = layer.output
assert(
bert_output.dense.weight.shape == roberta_layer.fc2.weight.shape
)
bert_output.dense.weight = roberta_layer.fc2.weight
bert_output.dense.bias = roberta_layer.fc2.bias
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
#### end of layer
# LM Head
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
model.lm_head.weight = roberta.model.decoder.lm_head.weight
model.lm_head.bias = roberta.model.decoder.lm_head.bias
# Let's check that we get the same results.
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
our_output = model(input_ids)[0]
their_output = roberta.model(input_ids)[0]
print(our_output.shape, their_output.shape)
success = torch.allclose(our_output, their_output, atol=1e-3)
print(
"Do both models output the same tensors?",
"🔥" if success else "💩"
)
if not success:
raise Exception("Something went wRoNg")
print(f"Saving model to {pytorch_dump_folder_path}")
model.save_pretrained(pytorch_dump_folder_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
## Required parameters
parser.add_argument("--roberta_checkpoint_path",
default = None,
type = str,
required = True,
help = "Path the official PyTorch dump.")
parser.add_argument("--pytorch_dump_folder_path",
default = None,
type = str,
required = True,
help = "Path to the output PyTorch model.")
args = parser.parse_args()
convert_roberta_checkpoint_to_pytorch(
args.roberta_checkpoint_path,
args.pytorch_dump_folder_path
)

View File

@ -0,0 +1,128 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch RoBERTa model. """
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_transformers.modeling_bert import (BertConfig, BertEmbeddings,
BertLayerNorm, BertModel,
BertPreTrainedModel, gelu)
logger = logging.getLogger(__name__)
ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-pytorch_model.bin",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-pytorch_model.bin",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-pytorch_model.bin",
}
ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP = {
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-config.json",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-config.json",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-large-mnli-config.json",
}
class RobertaEmbeddings(BertEmbeddings):
"""
Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
"""
def __init__(self, config):
super(RobertaEmbeddings, self).__init__(config)
self.padding_idx = 1
def forward(self, input_ids, token_type_ids=None, position_ids=None):
seq_length = input_ids.size(1)
if position_ids is None:
# Position numbers begin at padding_idx+1. Padding symbols are ignored.
# cf. fairseq's `utils.make_positions`
position_ids = torch.arange(self.padding_idx+1, seq_length+self.padding_idx+1, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
return super().forward(input_ids, token_type_ids=token_type_ids, position_ids=position_ids)
class RobertaConfig(BertConfig):
pretrained_config_archive_map = ROBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP
class RobertaModel(BertModel):
"""
Same as BertModel with:
- a tiny embeddings tweak.
- setup for Roberta pretrained models
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaModel, self).__init__(config)
self.embeddings = RobertaEmbeddings(config)
class RobertaForMaskedLM(BertPreTrainedModel):
"""
Roberta Model with a `language modeling` head on top.
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaForMaskedLM, self).__init__(config)
self.roberta = RobertaModel(config)
self.lm_head = RobertaLMHead(config)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)
outputs = (prediction_scores,) + outputs[2:]
return outputs
class RobertaLMHead(nn.Module):
"""Roberta Head for masked language modeling."""
def __init__(self, config: BertConfig):
super().__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.layer_norm = BertLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.weight = nn.Linear(config.hidden_size, config.vocab_size, bias=False).weight
self.bias = nn.Parameter(torch.zeros(config.vocab_size))
def forward(self, features, **kwargs):
x = self.dense(features)
x = gelu(x)
x = self.layer_norm(x)
# project back to size of vocabulary with bias
x = F.linear(x, self.weight) + self.bias
return x

View File

@ -0,0 +1,69 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import os
import unittest
import pytest
import torch
from pytorch_transformers.modeling_roberta import (RobertaForMaskedLM,
RobertaModel)
class RobertaModelTest(unittest.TestCase):
# @pytest.mark.slow
def test_inference_masked_lm(self):
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 50265))
self.assertEqual(
output.shape,
expected_shape
)
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[33.8843, -4.3107, 22.7779],
[ 4.6533, -2.8099, 13.6252],
[ 1.8222, -3.6898, 8.8600]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
# @pytest.mark.slow
def test_inference_no_head(self):
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[-0.0231, 0.0782, 0.0074],
[-0.1854, 0.0539, -0.0174],
[ 0.0548, 0.0799, 0.1687]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,42 @@
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import os
import unittest
import pytest
from pytorch_transformers.tokenization_roberta import RobertaTokenizer
class RobertaTokenizationTest(unittest.TestCase):
# @pytest.mark.slow
def test_full_tokenizer(self):
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
self.assertListEqual(
tokenizer.encode('Hello world!'),
[0, 31414, 232, 328, 2]
)
self.assertListEqual(
tokenizer.encode('Hello world! cécé herlolip'),
[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]
)
if __name__ == '__main__':
unittest.main()

View File

@ -0,0 +1,218 @@
# coding=utf-8
# Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tokenization classes for RoBERTa."""
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import json
import logging
import re
from .tokenization_utils import PreTrainedTokenizer
from .tokenization_gpt2 import GPT2Tokenizer
logger = logging.getLogger(__name__)
VOCAB_FILES_NAMES = {
'dict_file': 'dict.txt',
}
PRETRAINED_VOCAB_FILES_MAP = {
'dict_file':
{
'roberta-base': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
'roberta-large-mnli': "https://s3.amazonaws.com/models.huggingface.co/bert/roberta-base-dict.txt",
},
}
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
'roberta-base': 512,
'roberta-large': 512,
'roberta-large-mnli': 512,
}
SPACE_NORMALIZER = re.compile(r"\s+")
def tokenize_line(line):
line = SPACE_NORMALIZER.sub(" ", line)
line = line.strip()
return line.split()
class Dictionary(object):
"""
A mapping from symbols to consecutive integers
From Facebook's fairseq.
"""
def __init__(
self,
pad='<pad>',
eos='</s>',
unk='<unk>',
bos='<s>',
extra_special_symbols=None,
):
self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
self.symbols = []
self.count = []
self.indices = {}
self.bos_index = self.add_symbol(bos)
self.pad_index = self.add_symbol(pad)
self.eos_index = self.add_symbol(eos)
self.unk_index = self.add_symbol(unk)
if extra_special_symbols:
for s in extra_special_symbols:
self.add_symbol(s)
self.nspecial = len(self.symbols)
def __getitem__(self, idx):
if idx < len(self.symbols):
return self.symbols[idx]
return self.unk_word
def index(self, sym):
"""Returns the index of the specified symbol"""
assert isinstance(sym, str)
if sym in self.indices:
return self.indices[sym]
return self.unk_index
def add_symbol(self, word, n=1):
"""Adds a word to the dictionary"""
if word in self.indices:
idx = self.indices[word]
self.count[idx] = self.count[idx] + n
return idx
else:
idx = len(self.symbols)
self.indices[word] = idx
self.symbols.append(word)
self.count.append(n)
return idx
@classmethod
def load(cls, f, ignore_utf_errors=False):
"""Loads the dictionary from a text file with the format:
```
<symbol0> <count0>
<symbol1> <count1>
...
```
"""
d = cls()
d.add_from_file(f, ignore_utf_errors)
return d
def add_from_file(self, f, ignore_utf_errors=False):
"""
Loads a pre-existing dictionary from a text file and adds its symbols
to this instance.
"""
if isinstance(f, str):
try:
if not ignore_utf_errors:
with open(f, 'r', encoding='utf-8') as fd:
self.add_from_file(fd)
else:
with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
self.add_from_file(fd)
except FileNotFoundError as fnfe:
raise fnfe
except UnicodeError:
raise Exception("Incorrect encoding detected in {}, please "
"rebuild the dataset".format(f))
return
lines = f.readlines()
for line in lines:
idx = line.rfind(' ')
if idx == -1:
raise ValueError("Incorrect dictionary format, expected '<token> <cnt>'")
word = line[:idx]
count = int(line[idx + 1:])
self.indices[word] = len(self.symbols)
self.symbols.append(word)
self.count.append(count)
def encode_line(self, line, line_tokenizer=tokenize_line, add_if_not_exist=True,
consumer=None, append_eos=True, reverse_order=False):
words = line_tokenizer(line)
if reverse_order:
words = list(reversed(words))
nwords = len(words)
ids = [0] * (nwords + 1 if append_eos else nwords)
for i, word in enumerate(words):
if add_if_not_exist:
idx = self.add_symbol(word)
else:
idx = self.index(word)
if consumer is not None:
consumer(word, idx)
ids[i] = idx
if append_eos:
ids[nwords] = self.eos_index
return ids
class RobertaTokenizer(PreTrainedTokenizer):
"""
RoBERTa tokenizer. Peculiarities:
- GPT-2 tokenizer with a different integer mapping on top.
"""
vocab_files_names = VOCAB_FILES_NAMES
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
def __init__(self, dict_file,
bos_token="<s>", eos_token="</s>", **kwargs):
super(RobertaTokenizer, self).__init__(bos_token=bos_token, eos_token=eos_token, **kwargs)
self.gpt2_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
self.dictionary = Dictionary.load(dict_file)
def _tokenize(self, text):
""" Use GPT-2 Tokenizer """
return self.gpt2_tokenizer._tokenize(text)
def encode(self, text):
""" Converts a string in a sequence of ids (integer), using the tokenizer and vocabulary.
"""
gpt2_tokens_joined = " ".join(
str(x) for x in self.gpt2_tokenizer.convert_tokens_to_ids(self.tokenize(text))
)
bpe_sentence = '<s> ' + gpt2_tokens_joined + ' </s>'
return self.dictionary.encode_line(bpe_sentence, append_eos=False)
def _convert_token_to_id(self, token):
return self.dictionary.index(token)
def _convert_id_to_token(self, index):
symbol = self.dictionary[index]
try:
idx = int(symbol)
return self.gpt2_tokenizer._convert_id_to_token(idx)
except:
return symbol
def convert_tokens_to_string(self, tokens):
return self.gpt2_tokenizer.convert_tokens_to_string(tokens)