Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
[roberta.conversion] Do not hardcode vocab size; add support for fairseq 0.9+
This commit is contained in:
parent a4df2e0113
commit ea636440d1
@@ -22,6 +22,12 @@ import numpy as np
 import torch
 import pathlib
 
+import fairseq
+from packaging import version
+
+if version.parse(fairseq.__version__) < version.parse("0.9.0"):
+    raise Exception("requires fairseq >= 0.9.0")
+
 from fairseq.models.roberta import RobertaModel as FairseqRobertaModel
 from fairseq.modules import TransformerSentenceEncoderLayer
 from transformers.modeling_bert import (BertConfig, BertEncoder,
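
The additions above gate the converter on fairseq >= 0.9.0 using packaging.version, which compares release numbers numerically rather than lexically. A minimal standalone sketch of that check (the version strings are illustrative stand-ins for fairseq.__version__, not values from the diff):

from packaging import version

MINIMUM = version.parse("0.9.0")  # same threshold the script enforces

# A plain string comparison would mis-order "0.10.2" before "0.9.0";
# version.parse() compares the numeric components correctly.
for installed in ("0.8.0", "0.9.0", "0.10.2"):
    if version.parse(installed) < MINIMUM:
        print(installed, "-> would raise: requires fairseq >= 0.9.0")
    else:
        print(installed, "-> passes the gate")
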
@@ -46,8 +52,9 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
     """
     roberta = FairseqRobertaModel.from_pretrained(roberta_checkpoint_path)
     roberta.eval()  # disable dropout
+    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
     config = BertConfig(
-        vocab_size=50265,
+        vocab_size=roberta_sent_encoder.embed_tokens.num_embeddings,
         hidden_size=roberta.args.encoder_embed_dim,
         num_hidden_layers=roberta.args.encoder_layers,
         num_attention_heads=roberta.args.encoder_attention_heads,
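
This hunk replaces the hardcoded vocab_size=50265 with the size of the checkpoint's own token embedding table, read through torch.nn.Embedding's num_embeddings attribute. A small self-contained illustration (the 50265/768 shapes are just roberta-base dimensions used as an example, not loaded from a real checkpoint):

import torch.nn as nn

# Stand-in for roberta_sent_encoder.embed_tokens as loaded from a checkpoint.
embed_tokens = nn.Embedding(num_embeddings=50265, embedding_dim=768)

# The config value is now derived from the weights instead of a hardcoded constant,
# so checkpoints with a different vocabulary convert without edits.
vocab_size = embed_tokens.num_embeddings
print(vocab_size)  # 50265 here; other checkpoints may differ
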
@@ -65,7 +72,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
 
     # Now let's copy all the weights.
     # Embeddings
-    roberta_sent_encoder = roberta.model.decoder.sentence_encoder
     model.roberta.embeddings.word_embeddings.weight = roberta_sent_encoder.embed_tokens.weight
     model.roberta.embeddings.position_embeddings.weight = roberta_sent_encoder.embed_positions.weight
     model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight)  # just zero them out b/c RoBERTa doesn't use them.
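
The last hunk removes the roberta_sent_encoder assignment that the previous hunk moved up next to the config, leaving the embedding copies unchanged. The copy pattern itself, assigning the source Parameter directly and zeroing the unused token-type table, can be sketched with toy modules (the src/dst names are illustrative, not from the script):

import torch
import torch.nn as nn

src = nn.Embedding(10, 4)        # stands in for roberta_sent_encoder.embed_tokens
dst = nn.Embedding(10, 4)        # stands in for model.roberta.embeddings.word_embeddings
token_type = nn.Embedding(1, 4)  # RoBERTa does not use token type embeddings

dst.weight = src.weight                                       # reuse the same Parameter, no copy
token_type.weight.data = torch.zeros_like(token_type.weight)  # zero them out, as in the diff

assert dst.weight is src.weight
assert torch.all(token_type.weight == 0)
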