update main conversion script and readme
thomwolf 2019-06-25 10:45:07 +02:00
parent 7de1740490
commit 603c513b35
3 changed files with 96 additions and 49 deletions
README.md
@@ -1690,7 +1690,7 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncased` model:
 ```shell
 export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
-pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
+pytorch_pretrained_bert bert \
   $BERT_BASE_DIR/bert_model.ckpt \
   $BERT_BASE_DIR/bert_config.json \
   $BERT_BASE_DIR/pytorch_model.bin
@@ -1705,7 +1705,7 @@ Here is an example of the conversion process for a pre-trained OpenAI GPT model,
 ```shell
 export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights
-pytorch_pretrained_bert convert_openai_checkpoint \
+pytorch_pretrained_bert gpt \
   $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
   $PYTORCH_DUMP_OUTPUT \
   [OPENAI_GPT_CONFIG]
@@ -1718,7 +1718,7 @@ Here is an example of the conversion process for a pre-trained Transformer-XL model:
 ```shell
 export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint
-pytorch_pretrained_bert convert_transfo_xl_checkpoint \
+pytorch_pretrained_bert transfo_xl \
   $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
   $PYTORCH_DUMP_OUTPUT \
   [TRANSFO_XL_CONFIG]
@@ -1731,12 +1731,28 @@ Here is an example of the conversion process for a pre-trained OpenAI GPT-2 model:
 ```shell
 export GPT2_DIR=/path/to/gpt2/checkpoint
-pytorch_pretrained_bert convert_gpt2_checkpoint \
+pytorch_pretrained_bert gpt2 \
   $GPT2_DIR/model.ckpt \
   $PYTORCH_DUMP_OUTPUT \
   [GPT2_CONFIG]
 ```
+
+### XLNet
+
+Here is an example of the conversion process for a pre-trained XLNet model, fine-tuned on STS-B using the TensorFlow script:
+
+```shell
+export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint
+export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config
+
+pytorch_pretrained_bert xlnet \
+  $TRANSFO_XL_CHECKPOINT_PATH \
+  $TRANSFO_XL_CONFIG_PATH \
+  $PYTORCH_DUMP_OUTPUT \
+  STS-B \
+```
+
 ## TPU
 
 TPU support and pretraining scripts
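The new `xlnet` subcommand simply forwards these arguments to `convert_xlnet_checkpoint_to_pytorch` (see `__main__.py` below), so the same conversion can also be run from Python. A minimal sketch, not part of this commit, with placeholder paths:

```python
# Minimal sketch (not part of this commit): calling the XLNet converter directly
# instead of going through the `pytorch_pretrained_bert xlnet` CLI.
# The argument order mirrors the dispatch in __main__.py; paths are placeholders.
from pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch import (
    convert_xlnet_checkpoint_to_pytorch,
)

convert_xlnet_checkpoint_to_pytorch(
    "/path/to/xlnet/checkpoint",  # TF_CHECKPOINT
    "/path/to/xlnet/config",      # TF_CONFIG
    "/path/to/pytorch_dump",      # PYTORCH_DUMP_OUTPUT
    "STS-B",                      # FINETUNING_TASK (optional)
)
```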
pytorch_pretrained_bert/__main__.py
@@ -1,20 +1,16 @@
 # coding: utf8
 def main():
     import sys
-    if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
-        "convert_tf_checkpoint_to_pytorch",
-        "convert_openai_checkpoint",
-        "convert_transfo_xl_checkpoint",
-        "convert_gpt2_checkpoint",
-    ]:
+    if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]:
         print(
         "Should be used as one of: \n"
-        ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-        ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
-        ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
-        ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
+        ">> `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
+        ">> `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
+        ">> `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
+        ">> `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n"
+        ">> `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
     else:
-        if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
+        if sys.argv[1] == "bert":
             try:
                 from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
             except ImportError:
@@ -25,24 +21,28 @@ def main():
             if len(sys.argv) != 5:
                 # pylint: disable=line-too-long
-                print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
+                print("Should be used as `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
             else:
                 PYTORCH_DUMP_OUTPUT = sys.argv.pop()
                 TF_CONFIG = sys.argv.pop()
                 TF_CHECKPOINT = sys.argv.pop()
                 convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "convert_openai_checkpoint":
+        elif sys.argv[1] == "gpt":
             from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
-            OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
-            PYTORCH_DUMP_OUTPUT = sys.argv[3]
-            if len(sys.argv) == 5:
-                OPENAI_GPT_CONFIG = sys.argv[4]
-            else:
-                OPENAI_GPT_CONFIG = ""
-            convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
-                                                 OPENAI_GPT_CONFIG,
-                                                 PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "convert_transfo_xl_checkpoint":
+            if len(sys.argv) < 4 or len(sys.argv) > 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
+            else:
+                OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                if len(sys.argv) == 5:
+                    OPENAI_GPT_CONFIG = sys.argv[4]
+                else:
+                    OPENAI_GPT_CONFIG = ""
+                convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
+                                                     OPENAI_GPT_CONFIG,
+                                                     PYTORCH_DUMP_OUTPUT)
+        elif sys.argv[1] == "transfo_xl":
             try:
                 from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
             except ImportError:
@@ -50,20 +50,23 @@ def main():
                     "In that case, it requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
                 raise
-
-            if 'ckpt' in sys.argv[2].lower():
-                TF_CHECKPOINT = sys.argv[2]
-                TF_DATASET_FILE = ""
-            else:
-                TF_DATASET_FILE = sys.argv[2]
-                TF_CHECKPOINT = ""
-            PYTORCH_DUMP_OUTPUT = sys.argv[3]
-            if len(sys.argv) == 5:
-                TF_CONFIG = sys.argv[4]
-            else:
-                TF_CONFIG = ""
-            convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
-        else:
+            if len(sys.argv) < 4 or len(sys.argv) > 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+            else:
+                if 'ckpt' in sys.argv[2].lower():
+                    TF_CHECKPOINT = sys.argv[2]
+                    TF_DATASET_FILE = ""
+                else:
+                    TF_DATASET_FILE = sys.argv[2]
+                    TF_CHECKPOINT = ""
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                if len(sys.argv) == 5:
+                    TF_CONFIG = sys.argv[4]
+                else:
+                    TF_CONFIG = ""
+                convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
+        elif sys.argv[1] == "gpt2":
             try:
                 from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
             except ImportError:
@@ -72,12 +75,40 @@ def main():
                     "https://www.tensorflow.org/install/ for installation instructions.")
                 raise
-            TF_CHECKPOINT = sys.argv[2]
-            PYTORCH_DUMP_OUTPUT = sys.argv[3]
-            if len(sys.argv) == 5:
-                TF_CONFIG = sys.argv[4]
-            else:
-                TF_CONFIG = ""
-            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+            if len(sys.argv) < 4 or len(sys.argv) > 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+            else:
+                TF_CHECKPOINT = sys.argv[2]
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                if len(sys.argv) == 5:
+                    TF_CONFIG = sys.argv[4]
+                else:
+                    TF_CONFIG = ""
+                convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+        else:
+            try:
+                from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
+            except ImportError:
+                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
+                    "In that case, it requires TensorFlow to be installed. Please see "
+                    "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+            if len(sys.argv) < 5 or len(sys.argv) > 6:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
+            else:
+                TF_CHECKPOINT = sys.argv[2]
+                TF_CONFIG = sys.argv[3]
+                PYTORCH_DUMP_OUTPUT = sys.argv[4]
+                if len(sys.argv) == 6:
+                    FINETUNING_TASK = sys.argv[5]
+                else:
+                    FINETUNING_TASK = None
+                convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
+                                                    TF_CONFIG,
+                                                    PYTORCH_DUMP_OUTPUT,
+                                                    FINETUNING_TASK)
 
 if __name__ == '__main__':
     main()
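The dispatcher reads the subcommand and its arguments straight from `sys.argv`, so the renamed commands can also be exercised without the console-script entry point. A minimal sketch, not part of this commit, with placeholder paths:

```python
# Minimal sketch (not part of this commit): driving the new dispatcher from Python.
# Equivalent to: pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT
import sys
from pytorch_pretrained_bert.__main__ import main

sys.argv = [
    "pytorch_pretrained_bert",     # program name, ignored by the dispatcher
    "bert",                        # subcommand (was convert_tf_checkpoint_to_pytorch)
    "/path/to/bert_model.ckpt",    # TF_CHECKPOINT
    "/path/to/bert_config.json",   # TF_CONFIG
    "/path/to/pytorch_model.bin",  # PYTORCH_DUMP_OUTPUT
]
main()
```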
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py
@@ -70,7 +70,7 @@ if __name__ == "__main__":
                         required = True,
                         help = "The config json file corresponding to the pre-trained XLNet model. \n"
                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_folder_path",finetuning_task
+    parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
@@ -81,6 +81,6 @@ if __name__ == "__main__":
                         help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned")
     args = parser.parse_args()
     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.xlnet_config_file,
                                         args.pytorch_dump_folder_path,
                                         args.finetuning_task)
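The standalone script can still be invoked directly through its argparse interface. A minimal sketch, not part of this commit; only `--pytorch_dump_folder_path` appears in the diff above, and the other flag names are inferred from the `args.*` attributes, so they may differ:

```python
# Minimal sketch (not part of this commit): running the standalone XLNet conversion
# script as a module. Flag names other than --pytorch_dump_folder_path are inferred
# from the args.* attributes used above; all paths are placeholders.
import subprocess

subprocess.run(
    [
        "python", "-m", "pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch",
        "--tf_checkpoint_path", "/path/to/xlnet/checkpoint",
        "--xlnet_config_file", "/path/to/xlnet/config.json",
        "--pytorch_dump_folder_path", "/path/to/pytorch_dump",
        "--finetuning_task", "STS-B",
    ],
    check=True,
)
```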