Mirror of https://github.com/huggingface/transformers.git
[RoBERTa] RobertaForSequenceClassification + conversion
commit 9d0603148b (parent: d2cc6b101e)
@@ -30,6 +30,7 @@ from pytorch_transformers.modeling_bert import (BertConfig, BertEncoder,
                                                 BertSelfOutput)
 from pytorch_transformers.modeling_roberta import (RobertaEmbeddings,
                                                    RobertaForMaskedLM,
+                                                   RobertaForSequenceClassification,
                                                    RobertaModel)
 
 logging.basicConfig(level=logging.INFO)
@@ -38,7 +39,7 @@ logger = logging.getLogger(__name__)
 SAMPLE_TEXT = 'Hello world! cécé herlolip'
 
 
-def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path):
+def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_folder_path, classification_head):
     """
     Copy/paste/tweak roberta's weights to our BERT structure.
     """
@@ -53,9 +54,11 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         max_position_embeddings=514,
         type_vocab_size=1,
     )
+    if classification_head:
+        config.num_labels = roberta.args.num_classes
     print("Our BERT config:", config)
 
-    model = RobertaForMaskedLM(config)
+    model = RobertaForSequenceClassification(config) if classification_head else RobertaForMaskedLM(config)
     model.eval()
 
     # Now let's copy all the weights.
@@ -117,14 +120,20 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
         bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
         #### end of layer
 
-    # LM Head
-    model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
-    model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
-    model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
-    model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
-    model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
-    model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
-    model.lm_head.bias = roberta.model.decoder.lm_head.bias
+    if classification_head:
+        model.classifier.dense.weight = roberta.model.classification_heads['mnli'].dense.weight
+        model.classifier.dense.bias = roberta.model.classification_heads['mnli'].dense.bias
+        model.classifier.out_proj.weight = roberta.model.classification_heads['mnli'].out_proj.weight
+        model.classifier.out_proj.bias = roberta.model.classification_heads['mnli'].out_proj.bias
+    else:
+        # LM Head
+        model.lm_head.dense.weight = roberta.model.decoder.lm_head.dense.weight
+        model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
+        model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
+        model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
+        model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
+        model.lm_head.weight = roberta.model.decoder.lm_head.weight
+        model.lm_head.bias = roberta.model.decoder.lm_head.bias
 
     # Let's check that we get the same results.
     input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
@@ -157,8 +166,13 @@ if __name__ == "__main__":
                        type = str,
                        required = True,
                        help = "Path to the output PyTorch model.")
+    parser.add_argument("--classification_head",
+                        action = "store_true",
+                        help = "Whether to convert a final classification head.")
     args = parser.parse_args()
     convert_roberta_checkpoint_to_pytorch(
         args.roberta_checkpoint_path,
-        args.pytorch_dump_folder_path
+        args.pytorch_dump_folder_path,
+        args.classification_head
     )
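For reference, a minimal usage sketch of the updated entry point (not part of this commit). The module name and both paths are placeholders:

    # Hypothetical import; the actual script name is not shown in this diff.
    from convert_roberta_checkpoint import convert_roberta_checkpoint_to_pytorch

    convert_roberta_checkpoint_to_pytorch(
        roberta_checkpoint_path='/path/to/roberta.large.mnli',  # fairseq checkpoint directory (placeholder)
        pytorch_dump_folder_path='/path/to/pytorch_dump',       # output folder (placeholder)
        classification_head=True,  # copy fairseq's 'mnli' classification head instead of the LM head
    )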
@@ -142,3 +142,60 @@ class RobertaLMHead(nn.Module):
         x = self.decoder(x) + self.bias
 
         return x
+
+
+
+class RobertaForSequenceClassification(BertPreTrainedModel):
+    """
+    Roberta Model with a classifier head on top.
+    """
+    config_class = RobertaConfig
+    pretrained_model_archive_map = ROBERTA_PRETRAINED_MODEL_ARCHIVE_MAP
+    base_model_prefix = "roberta"
+
+    def __init__(self, config):
+        super(RobertaForSequenceClassification, self).__init__(config)
+        self.num_labels = config.num_labels
+
+        self.roberta = RobertaModel(config)
+        self.classifier = RobertaClassificationHead(config)
+
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
+                position_ids=None, head_mask=None):
+        outputs = self.roberta(input_ids, position_ids=position_ids, token_type_ids=token_type_ids,
+                               attention_mask=attention_mask, head_mask=head_mask)
+        sequence_output = outputs[0]
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits,) + outputs[2:]
+        if labels is not None:
+            if self.num_labels == 1:
+                # We are doing regression
+                loss_fct = MSELoss()
+                loss = loss_fct(logits.view(-1), labels.view(-1))
+            else:
+                loss_fct = CrossEntropyLoss()
+                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+            outputs = (loss,) + outputs
+
+        return outputs  # (loss), logits, (hidden_states), (attentions)
+
+
+
+class RobertaClassificationHead(nn.Module):
+    """Head for sentence-level classification tasks."""
+
+    def __init__(self, config):
+        super(RobertaClassificationHead, self).__init__()
+        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = torch.tanh(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
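A short usage sketch for the two classes added above (not part of the diff). The checkpoint name comes from the integration test added further below; the label value is illustrative:

    import torch
    from pytorch_transformers.modeling_roberta import RobertaForSequenceClassification

    # 'roberta-large-mnli' is the converted checkpoint exercised in the integration test below.
    model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
    model.eval()

    input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
    logits = model(input_ids)[0]  # shape (1, config.num_labels); (1, 3) for the MNLI head

    # With labels, the loss comes first: cross-entropy for num_labels > 1,
    # MSE (regression) when config.num_labels == 1.
    labels = torch.tensor([0])  # illustrative label id
    loss, logits = model(input_ids, labels=labels)[:2]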
@@ -179,5 +179,63 @@ class RobertaModelTest(CommonTestCases.CommonModelTester):
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)
 
+
+
+class RobertaModelIntegrationTest(unittest.TestCase):
+
+    @pytest.mark.slow
+    def test_inference_masked_lm(self):
+        model = RobertaForMaskedLM.from_pretrained('roberta-base')
+
+        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+        output = model(input_ids)[0]
+        expected_shape = torch.Size((1, 11, 50265))
+        self.assertEqual(
+            output.shape,
+            expected_shape
+        )
+        # compare the actual values for a slice.
+        expected_slice = torch.Tensor(
+            [[[33.8843, -4.3107, 22.7779],
+              [ 4.6533, -2.8099, 13.6252],
+              [ 1.8222, -3.6898,  8.8600]]]
+        )
+        self.assertTrue(
+            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
+        )
+
+    @pytest.mark.slow
+    def test_inference_no_head(self):
+        model = RobertaModel.from_pretrained('roberta-base')
+
+        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+        output = model(input_ids)[0]
+        # compare the actual values for a slice.
+        expected_slice = torch.Tensor(
+            [[[-0.0231,  0.0782,  0.0074],
+              [-0.1854,  0.0539, -0.0174],
+              [ 0.0548,  0.0799,  0.1687]]]
+        )
+        self.assertTrue(
+            torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3)
+        )
+
+    @pytest.mark.slow
+    def test_inference_classification_head(self):
+        model = RobertaForSequenceClassification.from_pretrained('roberta-large-mnli')
+
+        input_ids = torch.tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+        output = model(input_ids)[0]
+        expected_shape = torch.Size((1, 3))
+        self.assertEqual(
+            output.shape,
+            expected_shape
+        )
+        expected_tensor = torch.Tensor([[-0.9469, 0.3913, 0.5118]])
+        self.assertTrue(
+            torch.allclose(output, expected_tensor, atol=1e-3)
+        )
+
+
 if __name__ == "__main__":
     unittest.main()
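The hard-coded ids in these tests correspond to the SAMPLE_TEXT used by the conversion script. A rough sketch of how the expected tensors could be regenerated against the original fairseq implementation, assuming fairseq is installed and the checkpoint path below (a placeholder) is adjusted:

    from fairseq.models.roberta import RobertaModel as FairseqRobertaModel

    roberta = FairseqRobertaModel.from_pretrained('/path/to/roberta.base')  # placeholder path
    roberta.eval()

    input_ids = roberta.encode('Hello world! cécé herlolip').unsqueeze(0)
    print(input_ids)            # expected to match the ids hard-coded in the tests above

    features = roberta.extract_features(input_ids)
    print(features[:, :3, :3])  # compare with expected_slice in test_inference_no_head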