Add main() so the PyTorch -> TensorFlow checkpoint conversion can be called programmatically

This commit is contained in:
chrislarson1 2019-06-19 23:18:57 -04:00
parent a8e071c690
commit 716cc1c4d9

View File

@ -17,16 +17,18 @@
import os
import argparse
import torch
import numpy as np
import tensorflow as tf
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
from pytorch_pretrained_bert.modeling import BertModel
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
"""
:param model: BertModel PyTorch model instance to be converted
:param ckpt_dir: directory to save Tensorflow model
:param ckpt_dir: Tensorflow model directory
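:param model_name: model name (e.g. bert-base-uncased); used to build the checkpoint file name ("-" replaced by "_", with ".ckpt" appended)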
:return:
Currently supported HF models:
@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
saver = tf.train.Saver(tf_vars)
saver.save(session, os.path.join(ckpt_dir, args.pytorch_model_name))
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
if __name__ == "__main__":
def main(raw_args=None):
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_model_dir",
default=None,
parser.add_argument("--model_name",
type=str,
required=True,
help="model name e.g. bert-base-uncased")
parser.add_argument("--cache_dir",
type=str,
default=None,
required=False,
help="Directory containing pytorch model")
parser.add_argument("--pytorch_model_name",
default=None,
parser.add_argument("--pytorch_model_path",
type=str,
required=True,
help="model name (e.g. bert-base-uncased)")
parser.add_argument("--config_file_path",
default=None,
type=str,
required=True,
help="Path to bert config file")
parser.add_argument("--tf_checkpoint_dir",
default="",
help="/path/to/<pytorch-model-name>.bin")
parser.add_argument("--tf_cache_dir",
type=str,
required=True,
help="Directory in which to save tensorflow model")
args = parser.parse_args()
args = parser.parse_args(raw_args)
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir
)
convert_pytorch_checkpoint_to_tf(
model=model,
ckpt_dir=args.tf_cache_dir,
model_name=args.model_name
)
model = BertModel(
config=BertConfig(args.config_file_path)
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
if __name__ == "__main__":
main()