mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Update the command-line arguments of the HF (PyTorch) → TensorFlow BERT conversion script
This commit is contained in:
parent
077a5b0dc4
commit
f1433db4f1
@ -18,16 +18,18 @@
|
||||
import os
|
||||
import argparse
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
from pytorch_pretrained_bert.modeling import BertConfig, BertModel
|
||||
|
||||
|
||||
def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
def convert_hf_checkpoint_to_tf(model:type(BertModel), ckpt_dir:str):
|
||||
|
||||
"""
|
||||
:param model:BertModel Pytorch model instance to be converted
|
||||
:param ckpt_dir: directory to save Tensorflow model
|
||||
:return:
|
||||
|
||||
Supported HF models:
|
||||
Currently supported HF models:
|
||||
Y BertModel
|
||||
N BertForMaskedLM
|
||||
N BertForPreTraining
|
||||
@ -35,20 +37,13 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
N BertForNextSentencePrediction
|
||||
N BertForSequenceClassification
|
||||
N BertForQuestionAnswering
|
||||
|
||||
Note:
|
||||
To keep tf out of package-level requirements, it's imported locally.
|
||||
"""
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
if not os.path.isdir(ckpt_dir):
|
||||
os.makedirs(ckpt_dir)
|
||||
|
||||
session = tf.Session()
|
||||
|
||||
state_dict = model.state_dict()
|
||||
|
||||
tf_vars = []
|
||||
|
||||
def to_tf_var_name(name:str):
|
||||
@ -61,6 +56,7 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
name = name.replace('LayerNorm/weight', 'LayerNorm/gamma')
|
||||
name = name.replace('LayerNorm/bias', 'LayerNorm/beta')
|
||||
name = name.replace('weight', 'kernel')
|
||||
# name += ':0'
|
||||
return 'bert/{}'.format(name)
|
||||
|
||||
def assign_tf_var(tensor:np.ndarray, name:str):
|
||||
@ -81,44 +77,35 @@ def convert_hf_checkpoint_to_tf(model:BertModel, ckpt_dir:str):
|
||||
print("{0}{1}initialized".format(tf_name, " " * (60 - len(tf_name))))
|
||||
|
||||
saver = tf.train.Saver(tf_vars)
|
||||
saver.save(session, os.path.join(ckpt_dir, 'model'))
|
||||
saver.save(session, os.path.join(ckpt_dir, args.pytorch_model_name))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: load a pretrained Hugging Face (PyTorch) BERT model
    # and convert it to a TensorFlow checkpoint via
    # convert_hf_checkpoint_to_tf().
    parser = argparse.ArgumentParser()
    parser.add_argument("--pytorch_model_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="Directory containing pytorch model")
    parser.add_argument("--pytorch_model_name",
                        default=None,
                        type=str,
                        required=True,
                        help="model name (e.g. bert-base-uncased)")
    parser.add_argument("--config_file_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to bert config file.")
    parser.add_argument("--cache_dir",
                        default=None,
                        # FIX: help text was a copy-paste of
                        # --config_file_path's ("Path to bert config file");
                        # this option is not read anywhere in this script.
                        help="Optional cache directory (unused by this script)")
    parser.add_argument("--tf_checkpoint_dir",
                        # FIX: dropped the redundant default="" — the option
                        # is required=True, so the default can never apply.
                        default=None,
                        type=str,
                        required=True,
                        help="Directory in which to save tensorflow model")
    args = parser.parse_args()

    # NOTE(review): from_pretrained() is a classmethod in
    # pytorch_pretrained_bert, so the freshly constructed BertModel (and its
    # BertConfig built from --config_file_path) is discarded and a new model
    # is returned from the pretrained weights — confirm this is intended.
    model = BertModel(
        config=BertConfig(args.config_file_path)
    ).from_pretrained(args.pytorch_model_name, cache_dir=args.pytorch_model_dir)

    convert_hf_checkpoint_to_tf(model=model, ckpt_dir=args.tf_checkpoint_dir)
|
Loading…
Reference in New Issue
Block a user