Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 19:21:31 +06:00)
Fix roberta checkpoint conversion script (#3642)
parent 11cc1e168b
commit 5aa8a278a3
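For orientation, a rough sketch of how the conversion entry point touched below is driven. convert_roberta_checkpoint_to_pytorch and roberta_checkpoint_path appear in the diff itself; the remaining parameter names and the paths are illustrative assumptions, not taken from this commit:

    # Hypothetical invocation of the conversion function this commit fixes.
    # Assumes convert_roberta_checkpoint_to_pytorch from the conversion script is
    # importable or in scope; only the function name and roberta_checkpoint_path
    # are confirmed by the diff below, the rest are assumed for illustration.
    convert_roberta_checkpoint_to_pytorch(
        roberta_checkpoint_path="./roberta.large",      # directory with the fairseq checkpoint
        pytorch_dump_folder_path="./roberta-large-hf",  # output directory (assumed parameter name)
        classification_head=False,                      # convert the MLM head only (assumed parameter name)
    )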
@@ -25,15 +25,8 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version
 
-from transformers.modeling_bert import (
-    BertConfig,
-    BertIntermediate,
-    BertLayer,
-    BertOutput,
-    BertSelfAttention,
-    BertSelfOutput,
-)
-from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification
+from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
+from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
 
 
 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
@@ -55,7 +48,7 @@ def convert_roberta_checkpoint_to_pytorch(
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
     roberta_sent_encoder = roberta.model.decoder.sentence_encoder
-    config = BertConfig(
+    config = RobertaConfig(
         vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
         num_hidden_layers=roberta.args.encoder_layers,
@@ -138,7 +131,7 @@ def convert_roberta_checkpoint_to_pytorch(
     model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
     model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
     model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
-    model.lm_head.bias = roberta.model.decoder.lm_head.bias
+    model.lm_head.decoder.bias = roberta.model.decoder.lm_head.bias
 
     # Let's check that we get the same results.
     input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
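For readers not familiar with the script, a minimal sketch of what the two corrected lines amount to. It assumes the transformers version contemporary with this commit (where RobertaConfig and RobertaForMaskedLM live in transformers.modeling_roberta); the dimensions and the source tensor are hypothetical stand-ins for the values the script reads off the fairseq model:

    import torch
    from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM

    # Stand-in values for what the script reads from roberta.args and
    # roberta_sent_encoder; the real script takes them from the fairseq checkpoint.
    config = RobertaConfig(
        vocab_size=50265,
        hidden_size=768,
        num_hidden_layers=12,
        num_attention_heads=12,
        intermediate_size=3072,
    )
    model = RobertaForMaskedLM(config)
    model.eval()

    # Stand-in for roberta.model.decoder.lm_head.bias from the fairseq model.
    fairseq_lm_head_bias = torch.zeros(config.vocab_size)

    # The fixed assignment targets lm_head.decoder.bias (not lm_head.bias),
    # matching the lm_head.decoder.weight assignment on the preceding line of the script.
    model.lm_head.decoder.bias = torch.nn.Parameter(fairseq_lm_head_bias)

Building the config as RobertaConfig instead of BertConfig presumably also makes the configuration saved alongside the converted weights identify the model as RoBERTa rather than BERT, so the dumped checkpoint reloads cleanly with the Roberta classes.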