diff --git a/convert_tf_checkpoint_to_pytorch.py b/convert_tf_checkpoint_to_pytorch.py
index dfcdbee42d5..408951a9364 100644
--- a/convert_tf_checkpoint_to_pytorch.py
+++ b/convert_tf_checkpoint_to_pytorch.py
@@ -26,35 +26,14 @@ import numpy as np
 from modeling import BertConfig, BertModel
 
-parser = argparse.ArgumentParser()
-## Required parameters
-parser.add_argument("--tf_checkpoint_path",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "Path the TensorFlow checkpoint path.")
-parser.add_argument("--bert_config_file",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "The config json file corresponding to the pre-trained BERT model. \n"
-                        "This specifies the model architecture.")
-parser.add_argument("--pytorch_dump_path",
-                    default = None,
-                    type = str,
-                    required = True,
-                    help = "Path to the output PyTorch model.")
-
-args = parser.parse_args()
-
-def convert():
+def convert(config_path, ckpt_path, out_path=None):
     # Initialise PyTorch model
-    config = BertConfig.from_json_file(args.bert_config_file)
+    config = BertConfig.from_json_file(config_path)
     model = BertModel(config)
 
     # Load weights from TF model
-    path = args.tf_checkpoint_path
+    path = ckpt_path
 
     print("Converting TensorFlow checkpoint from {}".format(path))
 
     init_vars = tf.train.list_variables(path)
@@ -99,7 +78,32 @@ def convert():
         pointer.data = torch.from_numpy(array)
 
     # Save pytorch-model
-    torch.save(model.state_dict(), args.pytorch_dump_path)
+    if out_path is not None:
+        torch.save(model.state_dict(), out_path)
+    return model
+
 
 if __name__ == "__main__":
-    convert()
+    parser = argparse.ArgumentParser()
+
+    ## Required parameters
+    parser.add_argument("--tf_checkpoint_path",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="Path the TensorFlow checkpoint path.")
+    parser.add_argument("--bert_config_file",
+                        default=None,
+                        type=str,
+                        required=True,
+                        help="The config json file corresponding to the pre-trained BERT model. \n"
+                             "This specifies the model architecture.")
+    parser.add_argument("--pytorch_dump_path",
+                        default=None,
+                        type=str,
+                        required=False,
+                        help="Path to the output PyTorch model.")
+
+    args = parser.parse_args()
+    print(args)
+    convert(args.bert_config_file, args.tf_checkpoint_path, args.pytorch_dump_path)
diff --git a/modeling.py b/modeling.py
index c467e8266ef..4cbb99f2fab 100644
--- a/modeling.py
+++ b/modeling.py
@@ -355,7 +355,7 @@ class BertModel(nn.Module):
         all_encoder_layers = self.encoder(embedding_output, extended_attention_mask)
         sequence_output = all_encoder_layers[-1]
         pooled_output = self.pooler(sequence_output)
-        return all_encoder_layers, pooled_output
+        return [embedding_output] + all_encoder_layers, pooled_output
 
 class BertForSequenceClassification(nn.Module):
     """BERT model for classification.
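Taken together, the two changes above make the converter usable as a library: convert() now takes the config and checkpoint paths directly, only writes a state dict when out_path is given, and returns the loaded BertModel, whose forward pass now also exposes the embedding output as the first element of the returned list (13 tensors for a 12-layer bert-base config). A minimal sketch of the new calling pattern; the local paths are placeholders and not part of the patch:

import torch

import convert_tf_checkpoint_to_pytorch

# Placeholder paths; substitute a real bert_config.json / bert_model.ckpt pair.
config_path = "bert-base/bert_config.json"
ckpt_path = "bert-base/bert_model.ckpt"

# With out_path omitted, nothing is written to disk; the in-memory model is returned.
model = convert_tf_checkpoint_to_pytorch.convert(config_path, ckpt_path)
model.eval()

input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])

# Index 0 is now the embedding output; 1..num_hidden_layers are the encoder layer outputs.
all_layers, pooled = model(input_ids, token_type_ids, input_mask)
print(len(all_layers), all_layers[0].shape, all_layers[-1].shape, pooled.shape)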
diff --git a/tests/mytest.py b/tests/mytest.py
new file mode 100644
index 00000000000..2b2dadecda9
--- /dev/null
+++ b/tests/mytest.py
@@ -0,0 +1,71 @@
+import unittest
+import json
+import random
+
+import torch
+import numpy as np
+
+import modeling
+import convert_tf_checkpoint_to_pytorch
+
+import grouch
+
+
+class MyTest(unittest.TestCase):
+    def test_loading_and_running(self):
+        bertpath = "../../grouch/data/bert/bert-base/"
+        configpath = bertpath + "bert_config.json"
+        ckptpath = bertpath + "bert_model.ckpt"
+        m = convert_tf_checkpoint_to_pytorch.convert(configpath, ckptpath)
+        m.eval()
+        # print(m)
+
+        input_ids = torch.LongTensor([[31, 51, 99], [15, 5, 0]])
+        input_mask = torch.LongTensor([[1, 1, 1], [1, 1, 0]])
+        token_type_ids = torch.LongTensor([[0, 0, 1], [0, 1, 0]])
+
+        all_y, pool_y = m(input_ids, token_type_ids, input_mask)
+        print(pool_y.shape)
+        # np.save("_bert_ref_pool_out.npy", pool_y.detach().numpy())
+        # np.save("_bert_ref_all_out.npy", torch.stack(all_y, 0).detach().numpy())
+
+        config = grouch.TransformerBERT.load_config(configpath)
+        gm = grouch.TransformerBERT.init_from_config(config)
+        gm.load_weights_from_tf_checkpoint(ckptpath)
+        gm.eval()
+
+        g_all_y, g_pool_y = gm(input_ids, token_type_ids, input_mask)
+        print(g_pool_y.shape)
+
+        # check embeddings
+        # print(m.embeddings)
+        # print(gm.emb)
+        # hugging_emb = m.embeddings(input_ids, token_type_ids)
+        # grouch_emb = gm.emb(input_ids, token_type_ids)
+
+        print((all_y[0] - g_all_y[0]).norm())
+        # print(all_y[0][:, :, :10] - g_all_y[0][:, :, :10])
+        self.assertTrue(np.allclose(all_y[0].detach().numpy(), g_all_y[0].detach().numpy(), atol=1e-7))
+        print("embeddings good")
+
+        print(m.encoder.layer[0])
+        print(gm.encoder.layers[0])
+        print("norm of diff at layer 1", (all_y[1] - g_all_y[1]).norm())
+        # print(all_y[1][:, :, :10] - g_all_y[1][:, :, :10])
+        self.assertTrue(np.allclose(all_y[1].detach().numpy(), g_all_y[1].detach().numpy(), atol=1e-6))
+
+        # hugging_layer = m.encoder.layer[0]
+        # grouch_layer = gm.encoder.layers[0]
+        # print("comparing weights")
+        # print((hugging_layer.attention.self.query.weight - grouch_layer.slf_attn.q_proj.weight).norm())
+        # print((hugging_layer.attention.self.query.bias - grouch_layer.slf_attn.q_proj.bias).norm())
+        # print((hugging_layer.attention.self.key.weight - grouch_layer.slf_attn.k_proj.weight).norm())
+        # print((hugging_layer.attention.self.key.bias - grouch_layer.slf_attn.k_proj.bias).norm())
+        # print((hugging_layer.attention.self.value.weight - grouch_layer.slf_attn.v_proj.weight).norm())
+        # print((hugging_layer.attention.self.value.bias - grouch_layer.slf_attn.v_proj.bias).norm())
+        # print((hugging_layer.attention.output.dense.weight - grouch_layer.slf_attn.vw_proj.weight).norm())
+        # print((hugging_layer.attention.output.dense.bias - grouch_layer.slf_attn.vw_proj.bias).norm())
+
+        print("norm of diff at last layer", (all_y[-1] - g_all_y[-1]).norm())
+        # print(all_y[-1][:, :, :10] - g_all_y[-1][:, :, :10])
+        self.assertTrue(np.allclose(all_y[-1].detach().numpy(), g_all_y[-1].detach().numpy(), atol=1e-4))
\ No newline at end of file
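The new test compares the Hugging Face model against the grouch reimplementation at three points, loosening the tolerance from 1e-7 at the embeddings to 1e-4 at the last layer, since float32 discrepancies accumulate through the encoder stack. A sketch of the same check written as a loop over every returned layer; the helper name and tolerance schedule are illustrative only, not part of the patch:

import numpy as np


def assert_all_layers_close(all_y, g_all_y, atols=None):
    """Compare two equal-length lists of per-layer activations.

    Illustrative tolerance schedule: the embedding output is held to 1e-7,
    intermediate layers to 1e-6, and the last layer to 1e-4, mirroring the
    three spot checks in the test above.
    """
    assert len(all_y) == len(g_all_y) >= 2
    if atols is None:
        atols = [1e-7] + [1e-6] * (len(all_y) - 2) + [1e-4]
    for i, (a, b, atol) in enumerate(zip(all_y, g_all_y, atols)):
        # report the raw difference norm alongside the tolerance used
        print("layer {}: ||diff|| = {:.3e} (atol {:.0e})".format(i, (a - b).norm().item(), atol))
        assert np.allclose(a.detach().numpy(), b.detach().numpy(), atol=atol)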