Fix roberta checkpoint conversion script (#3642)

This commit is contained in:
Myle Ott 2020-04-07 12:03:23 -04:00 committed by GitHub
parent 11cc1e168b
commit 5aa8a278a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -25,15 +25,8 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from fairseq.modules import TransformerSentenceEncoderLayer
from packaging import version
from transformers.modeling_bert import (
BertConfig,
BertIntermediate,
BertLayer,
BertOutput,
BertSelfAttention,
BertSelfOutput,
)
from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification
from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
if version.parse(fairseq.__version__) < version.parse("0.9.0"):
@ -55,7 +48,7 @@ def convert_roberta_checkpoint_to_pytorch(
roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
roberta.eval() # disable dropout
roberta_sent_encoder = roberta.model.decoder.sentence_encoder
config = BertConfig(
config = RobertaConfig(
vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
hidden_size=roberta.args.encoder_embed_dim,
num_hidden_layers=roberta.args.encoder_layers,
@ -138,7 +131,7 @@ def convert_roberta_checkpoint_to_pytorch(
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
model.lm_head.bias = roberta.model.decoder.lm_head.bias
model.lm_head.decoder.bias = roberta.model.decoder.lm_head.bias
# Let's check that we get the same results.
input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0) # batch of size 1