update to hf->tf args

Chris 2019-05-18 17:09:08 -04:00
parent 077a5b0dc4
commit f1433db4f1

@@ -18,16 +18,18 @@
import os
import argparse
import numpy as np
import tensorflow as tf
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
"""
:param model:BertModel Pytorch model instance to be converted
:param ckpt_dir: directory to save Tensorflow model
:return:
Supported HF models:
Currently supported HF models:
Y BertModel
N BertForMaskedLM
N BertForPreTraining
@@ -35,20 +37,13 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
N BertForNextSentencePrediction
N BertForSequenceClassification
N BertForQuestionAnswering
Note:
To keep tf out of package-level requirements, it's imported locally.
"""
import tensorflow as tf
if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
session = tf.Session()
state_dict = model.state_dict()
tf_vars = []
def to_tf_var_name(name:str):
@@ -61,6 +56,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
name = name.replace('weight', 'kernel')
# name += ':0'
return 'bert/{}'.format(name)
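# Illustrative mapping (assumes the replacements earlier in this function,
# outside this hunk, have already turned '.' into '/' and 'layer.N' into 'layer_N'):
#   'encoder.layer.0.attention.self.query.weight'
#       -> 'bert/encoder/layer_0/attention/self/query/kernel'
#   'encoder.layer.0.attention.output.LayerNorm.weight'
#       -> 'bert/encoder/layer_0/attention/output/LayerNorm/gamma'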
def assign_tf_var(tensor:np.ndarray, name:str):
@@ -81,44 +77,35 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
saver = tf.train.Saver(tf_vars)
saver.save(session, os.path.join(ckpt_dir, 'model'))
saver.save(session, os.path.join(ckpt_dir, args.pytorch_model_name))
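# Note: tf.train.Saver.save() writes <prefix>.index, <prefix>.meta and
# <prefix>.data-* files plus a 'checkpoint' file in ckpt_dir, so the value of
# --pytorch_model_name now becomes the TF checkpoint prefix instead of the
# hard-coded 'model'.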
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_name_or_path",
parser.add_argument("--pytorch_model_dir",
default=None,
type=str,
required=True,
help="pretrained_model_name_or_path: either: \
- a str with the name of a pre-trained model to load selected in the list of: \
. `bert-base-uncased` \
. `bert-large-uncased` \
. `bert-base-cased` \
. `bert-large-cased` \
. `bert-base-multilingual-uncased` \
. `bert-base-multilingual-cased` \
. `bert-base-chinese` \
- a path or url to a pretrained model archive containing: \
. `bert_config.json` a configuration file for the model \
. `pytorch_model.bin` a PyTorch dump of a BertForPreTraining instance \
- a path or url to a pretrained model archive containing: \
. `bert_config.json` a configuration file for the model \
. `model.ckpt` a TensorFlow checkpoint")
help="Directory containing pytorch model")
parser.add_argument("--pytorch_model_name",
default=None,
type=str,
required=True,
help="model name (e.g. bert-base-uncased)")
parser.add_argument("--config_file_path",
default=None,
type=str,
required=True,
help="Path to bert config file.")
parser.add_argument("--cache_dir",
default=None,
help="Path to bert config file")
parser.add_argument("--tf_checkpoint_dir",
default="",
type=str,
required=True,
help="Path to a folder in which the TF model will be cached.")
help="Directory in which to save tensorflow model")
args = parser.parse_args()
model = BertModel(
config=BertConfig(args.config_file_path)
).from_pretrained(args.pretrained_model_name_or_path)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.cache_dir)
).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)
convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
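For reference, an invocation with the updated arguments might look like the following; only the flag names come from the argparse block above, while the script name and paths are placeholders:

    python convert_hf_checkpoint_to_tf.py \
        --pytorch_model_dir /path/to/pytorch/model \
        --pytorch_model_name bert-base-uncased \
        --config_file_path /path/to/bert_config.json \
        --tf_checkpoint_dir /path/to/tf/checkpoint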