mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 10:12:23 +06:00
[RoBERTa] LayerNorm's eps is not a nn.Parameter so there's no point setting it on the model
Instead we correctly store it on the config (regenerating the hosted config files) cc @lysandrejik
This commit is contained in:
parent
09363f2a8b
commit
574c5b3a72
@ -53,6 +53,7 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
intermediate_size=roberta.args.encoder_ffn_embed_dim,
|
||||
max_position_embeddings=514,
|
||||
type_vocab_size=1,
|
||||
layer_norm_eps=1e-5, # PyTorch default used in fairseq
|
||||
)
|
||||
if classification_head:
|
||||
config.num_labels = roberta.args.num_classes
|
||||
@ -69,7 +70,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
model.roberta.embeddings.token_type_embeddings.weight.data = torch.zeros_like(model.roberta.embeddings.token_type_embeddings.weight) # just zero them out b/c RoBERTa doesn't use them.
|
||||
model.roberta.embeddings.LayerNorm.weight = roberta_sent_encoder.emb_layer_norm.weight
|
||||
model.roberta.embeddings.LayerNorm.bias = roberta_sent_encoder.emb_layer_norm.bias
|
||||
model.roberta.embeddings.LayerNorm.variance_epsilon = roberta_sent_encoder.emb_layer_norm.eps
|
||||
|
||||
for i in range(config.num_hidden_layers):
|
||||
# Encoder: start of layer
|
||||
@ -98,7 +98,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
self_output.dense.bias = roberta_layer.self_attn.out_proj.bias
|
||||
self_output.LayerNorm.weight = roberta_layer.self_attn_layer_norm.weight
|
||||
self_output.LayerNorm.bias = roberta_layer.self_attn_layer_norm.bias
|
||||
self_output.LayerNorm.variance_epsilon = roberta_layer.self_attn_layer_norm.eps
|
||||
|
||||
### intermediate
|
||||
intermediate: BertIntermediate = layer.intermediate
|
||||
@ -117,7 +116,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
bert_output.dense.bias = roberta_layer.fc2.bias
|
||||
bert_output.LayerNorm.weight = roberta_layer.final_layer_norm.weight
|
||||
bert_output.LayerNorm.bias = roberta_layer.final_layer_norm.bias
|
||||
bert_output.LayerNorm.variance_epsilon = roberta_layer.final_layer_norm.eps
|
||||
#### end of layer
|
||||
|
||||
if classification_head:
|
||||
@ -131,7 +129,6 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
model.lm_head.dense.bias = roberta.model.decoder.lm_head.dense.bias
|
||||
model.lm_head.layer_norm.weight = roberta.model.decoder.lm_head.layer_norm.weight
|
||||
model.lm_head.layer_norm.bias = roberta.model.decoder.lm_head.layer_norm.bias
|
||||
model.lm_head.layer_norm.variance_epsilon = roberta.model.decoder.lm_head.layer_norm.eps
|
||||
model.lm_head.decoder.weight = roberta.model.decoder.lm_head.weight
|
||||
model.lm_head.bias = roberta.model.decoder.lm_head.bias
|
||||
|
||||
@ -144,6 +141,8 @@ def convert_roberta_checkpoint_to_pytorch(roberta_checkpoint_path, pytorch_dump_
|
||||
else:
|
||||
their_output = roberta.model(input_ids)[0]
|
||||
print(our_output.shape, their_output.shape)
|
||||
max_absolute_diff = torch.max(torch.abs(our_output - their_output)).item()
|
||||
print(f"max_absolute_diff = {max_absolute_diff}") # ~ 1e-7
|
||||
success = torch.allclose(our_output, their_output, atol=1e-3)
|
||||
print(
|
||||
"Do both models output the same tensors?",
|
||||
|
Loading…
Reference in New Issue
Block a user