commit 4e52188433 (parent 907d3569c1)
https://github.com/huggingface/transformers.git

bert weight loading from tf
convert_tf_checkpoint_to_pytorch.py
@@ -26,35 +26,14 @@ import numpy as np
 
 from modeling import BertConfig, BertModel
 
-parser = argparse.ArgumentParser()
-
-## Required parameters
-parser.add_argument("--tf_checkpoint_path",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "Path the TensorFlow checkpoint path.")
-parser.add_argument("--bert_config_file",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "The config json file corresponding to the pre-trained BERT model. \n"
-                           "This specifies the model architecture.")
-parser.add_argument("--pytorch_dump_path",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "Path to the output PyTorch model.")
-
-args = parser.parse_args()
-
-def convert():
+def convert(config_path, ckpt_path, out_path=None):
     # Initialise PyTorch model
-    config = BertConfig.from_json_file(args.bert_config_file)
+    config = BertConfig.from_json_file(config_path)
     model = BertModel(config)
 
     # Load weights from TF model
-    path = args.tf_checkpoint_path
+    path = ckpt_path
     print("Converting TensorFlow checkpoint from {}".format(path))
 
     init_vars = tf.train.list_variables(path)
@@ -99,7 +78,32 @@ def convert():
     pointer.data = torch.from_numpy(array)
 
     # Save pytorch-model
-    torch.save(model.state_dict(), args.pytorch_dump_path)
+    if out_path is not None:
+        torch.save(model.state_dict(), out_path)
+    return model
 
 if __name__ == "__main__":
-    convert()
+    parser = argparse.ArgumentParser()
+
+    ## Required parameters
+    parser.add_argument("--tf_checkpoint_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path the TensorFlow checkpoint path.")
+    parser.add_argument("--bert_config_file",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The config json file corresponding to the pre-trained BERT model. \n"
+                             "This specifies the model architecture.")
+    parser.add_argument("--pytorch_dump_path",
+                        default=None,
+                        type=str,
+                        required=False,
+                        help="Path to the output PyTorch model.")
+
+    args = parser.parse_args()
+    print(args)
+    convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
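Note: with this refactor the converter is importable as well as runnable. convert() now takes explicit paths, only writes a checkpoint when out_path is given (--pytorch_dump_path is accordingly no longer required), and returns the loaded BertModel. A minimal usage sketch, with illustrative file names:

    from convert_tf_checkpoint_to_pytorch import convert

    # Build a PyTorch BertModel from a TF checkpoint without touching disk.
    model = convert("bert_config.json", "bert_model.ckpt")

    # Convert and also dump a state_dict, matching the old CLI behaviour.
    model = convert("bert_config.json", "bert_model.ckpt", out_path="pytorch_model.bin")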
modeling.py
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
         all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
         sequence_output = all_encoder_layers[-1]
         pooled_output = self.pooler(sequence_output)
-        return all_encoder_layers, pooled_output
+        return [embedding_output] + all_encoder_layers, pooled_output
 
 
 class BertForSequenceClassification(nn.Module):
     """BERT model for classification.
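Note: this changes the shape of BertModel's first return value for existing callers. The list now has num_hidden_layers + 1 entries, with the embedding output prepended at index 0, so encoder layer i moves to index i + 1. A sketch of the new indexing:

    all_layers, pooled = model(input_ids, token_type_ids, input_mask)
    embedding_output = all_layers[0]    # new: output of the embedding block
    first_layer      = all_layers[1]    # encoder layer 0 (was all_layers[0] before this commit)
    sequence_output  = all_layers[-1]   # final encoder layer; what the pooler sees, unchanged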
tests/mytest.py (new file, 71 lines)
@@ -0,0 +1,71 @@
+import unittest
+import json
+import random
+
+import torch
+import numpy as np
+
+import modeling
+import convert_tf_checkpoint_to_pytorch
+
+import grouch
+
+
+class MyTest(unittest.TestCase):
+    def test_loading_and_running(self):
+        bertpath = "../../grouch/data/bert/bert-base/"
+        configpath = bertpath + "bert_config.json"
+        ckptpath = bertpath + "bert_model.ckpt"
+        m = convert_tf_checkpoint_to_pytorch.convert(configpath, ckptpath)
+        m.eval()
+        # print(m)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        all_y, pool_y = m(input_ids, token_type_ids, input_mask)
+        print(pool_y.shape)
+        # np.save("_bert_ref_pool_out.npy", pool_y.detach().numpy())
+        # np.save("_bert_ref_all_out.npy", torch.stack(all_y, 0).detach().numpy())
+
+        config = grouch.TransformerBERT.load_config(configpath)
+        gm = grouch.TransformerBERT.init_from_config(config)
+        gm.load_weights_from_tf_checkpoint(ckptpath)
+        gm.eval()
+
+        g_all_y, g_pool_y = gm(input_ids, token_type_ids, input_mask)
+        print(g_pool_y.shape)
+
+        # check embeddings
+        # print(m.embeddings)
+        # print(gm.emb)
+        # hugging_emb = m.embeddings(input_ids, token_type_ids)
+        # grouch_emb = gm.emb(input_ids, token_type_ids)
+
+        print((all_y[0] - g_all_y[0]).norm())
+        # print(all_y[0][:, :, :10] - g_all_y[0][:, :, :10])
+        self.assertTrue(np.allclose(all_y[0].detach().numpy(), g_all_y[0].detach().numpy(), atol=1e-7))
+        print("embeddings good")
+
+        print(m.encoder.layer[0])
+        print(gm.encoder.layers[0])
+        print("norm of diff at layer 1", (all_y[1] - g_all_y[1]).norm())
+        # print(all_y[1][:, :, :10] - g_all_y[1][:, :, :10])
+        self.assertTrue(np.allclose(all_y[1].detach().numpy(), g_all_y[1].detach().numpy(), atol=1e-6))
+
+        # hugging_layer = m.encoder.layer[0]
+        # grouch_layer = gm.encoder.layers[0]
+        # print("comparing weights")
+        # print((hugging_layer.attention.self.query.weight - grouch_layer.slf_attn.q_proj.weight).norm())
+        # print((hugging_layer.attention.self.query.bias - grouch_layer.slf_attn.q_proj.bias).norm())
+        # print((hugging_layer.attention.self.key.weight - grouch_layer.slf_attn.k_proj.weight).norm())
+        # print((hugging_layer.attention.self.key.bias - grouch_layer.slf_attn.k_proj.bias).norm())
+        # print((hugging_layer.attention.self.value.weight - grouch_layer.slf_attn.v_proj.weight).norm())
+        # print((hugging_layer.attention.self.value.bias - grouch_layer.slf_attn.v_proj.bias).norm())
+        # print((hugging_layer.attention.output.dense.weight - grouch_layer.slf_attn.vw_proj.weight).norm())
+        # print((hugging_layer.attention.output.dense.bias - grouch_layer.slf_attn.vw_proj.bias).norm())
+
+        print("norm of diff at last layer", (all_y[-1] - g_all_y[-1]).norm())
+        # print(all_y[-1][:, :, :10] - g_all_y[-1][:, :, :10])
+        self.assertTrue(np.allclose(all_y[-1].detach().numpy(), g_all_y[-1].detach().numpy(), atol=1e-4))