[RoBERTa] RobertaForSequenceClassification + conversion

This commit is contained in:
Julien Chaumond 2019-08-08 11:24:54 -04:00
parent d2cc6b101e
commit 9d0603148b
3 changed files with 140 additions and 11 deletions

View File

@ -30,6 +30,7 @@ from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
BertSelfOutput)
from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
RobertaForMaskedLM,
RobertaForSequenceClassification,
RobertaModel)
logging.basicConfig(level=logging.INFO)
@ -38,7 +39,7 @@ logger = logging.getLogger(__name__)
SAMPLE_TEXT = 'Hello world! cécé herlolip'
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path):
def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
"""
Copy/paste/tweak roberta's weights to our BERT structure.
"""
@ -53,9 +54,11 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
max_position_embeddings=514,
type_vocab_size=1,
)
if classification_head:
config.num_labels = roberta.args.num_classes
print("Our BERT config:", config)
model = RobertaForMaskedLM(config)
model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
model.eval()
# Now let's copy all the weights.
@ -117,14 +120,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
#### end of layer
# LM Head
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
model.lm_head.bias = roberta.model.decoder.lm_head.bias
if classification_head:
model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
else:
# LM Head
model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
model.lm_head.weight = roberta.model.decoder.lm_head.weight
model.lm_head.bias = roberta.model.decoder.lm_head.bias
# Let's check that we get the same results.
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1
@ -157,8 +166,13 @@ if __name__ == "__main__":
type = str,
required = True,
help = "Path to the output PyTorch model.")
parser.add_argument("--classification_head",
action = "store_true",
help = "Whether to convert a final classification head.")
args = parser.parse_args()
convert_roberta_checkpoint_to_pytorch(
args.roberta_checkpoint_path,
args.pytorch_dump_folder_path
args.pytorch_dump_folder_path,
args.classification_head
)

View File

@ -142,3 +142,60 @@ class RobertaLMHead(nn.Module):
x = self.decoder(x) + self.bias
return x
class RobertaForSequenceClassification(BertPreTrainedModel):
"""
Roberta Model with a classifier head on top.
"""
config_class = RobertaConfig
pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
base_model_prefix = "roberta"
def __init__(self, config):
super(RobertaForSequenceClassification, self).__init__(config)
self.num_labels = config.num_labels
self.roberta = RobertaModel(config)
self.classifier = RobertaClassificationHead(config)
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
position_ids=None, head_mask=None):
outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
attention_mask=attention_mask, head_mask=head_mask)
sequence_output = outputs[0]
logits = self.classifier(sequence_output)
outputs = (logits,) + outputs[2:]
if labels is not None:
if self.num_labels == 1:
# We are doing regression
loss_fct = MSELoss()
loss = loss_fct(logits.view(-1), labels.view(-1))
else:
loss_fct = CrossEntropyLoss()
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
outputs = (loss,) + outputs
return outputs # (loss), logits, (hidden_states), (attentions)
class RobertaClassificationHead(nn.Module):
"""Head for sentence-level classification tasks."""
def __init__(self, config):
super(RobertaClassificationHead, self).__init__()
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
self.dropout = nn.Dropout(config.hidden_dropout_prob)
self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
def forward(self, features, **kwargs):
x = features[:, 0, :] # take <s> token (equiv. to [CLS])
x = self.dropout(x)
x = self.dense(x)
x = torch.tanh(x)
x = self.dropout(x)
x = self.out_proj(x)
return x

View File

@ -179,5 +179,63 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
shutil.rmtree(cache_dir)
self.assertIsNotNone(model)
class RobertaModelIntegrationTest(unittest.TestCase):
@pytest.mark.slow
def test_inference_masked_lm(self):
model = RobertaForMaskedLM.from_pretrained('roberta-base')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 11, 50265))
self.assertEqual(
output.shape,
expected_shape
)
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[33.8843, -4.3107, 22.7779],
[ 4.6533, -2.8099, 13.6252],
[ 1.8222, -3.6898, 8.8600]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
@pytest.mark.slow
def test_inference_no_head(self):
model = RobertaModel.from_pretrained('roberta-base')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
# compare the actual values for a slice.
expected_slice = torch.Tensor(
[[[-0.0231, 0.0782, 0.0074],
[-0.1854, 0.0539, -0.0174],
[ 0.0548, 0.0799, 0.1687]]]
)
self.assertTrue(
torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
)
@pytest.mark.slow
def test_inference_classification_head(self):
model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
input_ids = torch.tensor([[ 0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
output = model(input_ids)[0]
expected_shape = torch.Size((1, 3))
self.assertEqual(
output.shape,
expected_shape
)
expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
self.assertTrue(
torch.allclose(output, expected_tensor, atol=1e-3)
)
if __name__ == "__main__":
unittest.main()