Fix roberta checkpoint conversion script (#3642)
parent 11cc1e168b
commit 5aa8a278a3
@@ -25,15 +25,8 @@ from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from packaging import version
 
-from transformers.modeling_bert import (
-    BertConfig,
-    BertIntermediate,
-    BertLayer,
-    BertOutput,
-    BertSelfAttention,
-    BertSelfOutput,
-)
-from transformers.modeling_roberta import RobertaForMaskedLM, RobertaForSequenceClassification
+from transformers.modeling_bert import BertIntermediate, BertLayer, BertOutput, BertSelfAttention, BertSelfOutput
+from transformers.modeling_roberta import RobertaConfig, RobertaForMaskedLM, RobertaForSequenceClassification
 
 
 if version.parse(fairseq.__version__) < version.parse("0.9.0"):
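This import change carries half of the fix: BertConfig is dropped and RobertaConfig is imported from transformers.modeling_roberta instead, so the configuration saved next to the converted weights describes a roberta model rather than a bert one. A minimal sketch of the difference, assuming the transformers 2.x module layout used by this script and that model_type is exposed as a class attribute on both config classes:

# Sketch only, not part of the diff: the config class determines the
# model_type recorded in the exported config.json.
from transformers.modeling_bert import BertConfig
from transformers.modeling_roberta import RobertaConfig

print(BertConfig().model_type)     # "bert"    (what the old script would have recorded)
print(RobertaConfig().model_type)  # "roberta" (what loading a roberta checkpoint expects)
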
@@ -55,7 +48,7 @@ def convert_roberta_checkpoint_to_pytorch(
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
     roberta_sent_encoder = roberta.model.decoder.sentence_encoder
-    config = BertConfig(
+    config = RobertaConfig(
         vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
         num_hidden_layers=roberta.args.encoder_layers,
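The hunk above is truncated after num_hidden_layers, so the remaining keyword arguments are not shown. As a hedged illustration of how a fairseq roberta model's hyperparameters typically map onto such a config (encoder_attention_heads and encoder_ffn_embed_dim are standard fairseq argument names, and 514 positions with a single token type are the usual roberta conventions; none of this is quoted from the script):

# Illustrative sketch, not a quote of the script's remaining lines.
from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
from transformers.modeling_roberta import RobertaConfig

roberta = FairseqRobertaModel.from_pretrained("/path/to/roberta.base")  # hypothetical path
roberta.eval()  # disable dropout
roberta_sent_encoder = roberta.model.decoder.sentence_encoder

config = RobertaConfig(
    vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
    hidden_size=roberta.args.encoder_embed_dim,
    num_hidden_layers=roberta.args.encoder_layers,
    num_attention_heads=roberta.args.encoder_attention_heads,  # assumed fairseq arg name
    intermediate_size=roberta.args.encoder_ffn_embed_dim,      # assumed fairseq arg name
    max_position_embeddings=514,  # roberta convention: 512 positions plus 2 offset tokens
    type_vocab_size=1,            # roberta does not use token type embeddings
)
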
@@ -138,7 +131,7 @@ def convert_roberta_checkpoint_to_pytorch(
     model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
     model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
     model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
-    model.lm_head.bias = roberta.model.decoder.lm_head.bias
+    model.lm_head.decoder.bias = roberta.model.decoder.lm_head.bias
 
     # Let's check that we get the same results.
     input_ids: torch.Tensor = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1
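The one-line change here is the other half of the fix: the fairseq LM-head bias is now copied onto model.lm_head.decoder.bias, the bias of the masked-LM output projection, rather than onto model.lm_head.bias. The hunk ends where the script starts checking that the converted model and the original fairseq model produce the same outputs. A minimal sketch of such a check, continuing in the script's context (model, roberta, SAMPLE_TEXT and torch as defined there), with variable names and tolerance that are illustrative rather than quoted:

# Illustrative equivalence check; names and tolerance are assumptions.
input_ids = roberta.encode(SAMPLE_TEXT).unsqueeze(0)  # batch of size 1

our_output = model(input_ids)[0]            # transformers RobertaForMaskedLM logits
their_output = roberta.model(input_ids)[0]  # fairseq roberta logits

max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
print("max_absolute_diff = {}".format(max_absolute_diff))
assert torch.allclose(our_output, their_output, atol=1e-3), "converted model diverges from fairseq"

A tolerance around 1e-3 is typical for a float32 conversion of this size; the exact value used by the script is not shown in this hunk.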