mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
Added main() so the PyTorch -> TensorFlow checkpoint conversion can be called programmatically
This commit is contained in:
parent
a8e071c690
commit
716cc1c4d9
@@ -17,16 +17,18 @@
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import torch
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
|
||||
from pytorch_pretrained_bert.modeling import BertModel
|
||||
|
||||
|
||||
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str, model_name:str):
|
||||
|
||||
"""
|
||||
:param model:BertModel Pytorch model instance to be converted
|
||||
:param ckpt_dir: directory to save Tensorflow model
|
||||
:param ckpt_dir: Tensorflow model directory
|
||||
:param model_name: model name
|
||||
:return:
|
||||
|
||||
Currently supported HF models:
|
||||
@@ -87,35 +89,42 @@ def convert_pytorch_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
|
||||
|
||||
saver = tf.train.Saver(tf_vars)
|
||||
saver.save(session, os.path.join(ckpt_dir, args.pytorch_model_name))
|
||||
saver.save(session, os.path.join(ckpt_dir, model_name.replace("-", "_") + ".ckpt"))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def main(raw_args=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--pytorch_model_dir",
|
||||
default=None,
|
||||
parser.add_argument("--model_name",
|
||||
type=str,
|
||||
required=True,
|
||||
help="model name e.g. bert-base-uncased")
|
||||
parser.add_argument("--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Directory containing pytorch model")
|
||||
parser.add_argument("--pytorch_model_name",
|
||||
default=None,
|
||||
parser.add_argument("--pytorch_model_path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="model name (e.g. bert-base-uncased)")
|
||||
parser.add_argument("--config_file_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to bert config file")
|
||||
parser.add_argument("--tf_checkpoint_dir",
|
||||
default="",
|
||||
help="/path/to/<pytorch-model-name>.bin")
|
||||
parser.add_argument("--tf_cache_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory in which to save tensorflow model")
|
||||
args = parser.parse_args()
|
||||
args = parser.parse_args(raw_args)
|
||||
|
||||
model = BertModel.from_pretrained(
|
||||
pretrained_model_name_or_path=args.model_name,
|
||||
state_dict=torch.load(args.pytorch_model_path),
|
||||
cache_dir=args.cache_dir
|
||||
)
|
||||
|
||||
convert_pytorch_checkpoint_to_tf(
|
||||
model=model,
|
||||
ckpt_dir=args.tf_cache_dir,
|
||||
model_name=args.model_name
|
||||
)
|
||||
|
||||
model = BertModel(
|
||||
config=BertConfig(args.config_file_path)
|
||||
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
|
||||
convert_pytorch_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
Loading…
Reference in New Issue
Block a user