Mirror of https://github.com/huggingface/transformers.git

Commit 603c513b35 (parent 7de1740490): update main conversion script and readme
README.md (24 changed lines)

@@ -1690,7 +1690,7 @@ Here is an example of the conversion process for a pre-trained `BERT-Base Uncase
 ```shell
 export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12
 
-pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch \
+pytorch_pretrained_bert bert \
   $BERT_BASE_DIR/bert_model.ckpt \
   $BERT_BASE_DIR/bert_config.json \
   $BERT_BASE_DIR/pytorch_model.bin
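For reference, the renamed `bert` sub-command ends up calling the same converter function imported in the `__main__.py` hunk later in this commit. A programmatic equivalent of the README example (an editorial sketch, not part of the commit; the installed-package import path and the literal paths are assumptions):

```python
# Editorial sketch: call the BERT converter directly instead of via the CLI.
from pytorch_pretrained_bert.convert_tf_checkpoint_to_pytorch import (
    convert_tf_checkpoint_to_pytorch,
)

BERT_BASE_DIR = "/path/to/bert/uncased_L-12_H-768_A-12"
convert_tf_checkpoint_to_pytorch(
    BERT_BASE_DIR + "/bert_model.ckpt",    # TF checkpoint
    BERT_BASE_DIR + "/bert_config.json",   # model config
    BERT_BASE_DIR + "/pytorch_model.bin",  # PyTorch dump output
)
```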
@@ -1705,7 +1705,7 @@ Here is an example of the conversion process for a pre-trained OpenAI GPT model,
 ```shell
 export OPENAI_GPT_CHECKPOINT_FOLDER_PATH=/path/to/openai/pretrained/numpy/weights
 
-pytorch_pretrained_bert convert_openai_checkpoint \
+pytorch_pretrained_bert gpt \
   $OPENAI_GPT_CHECKPOINT_FOLDER_PATH \
   $PYTORCH_DUMP_OUTPUT \
   [OPENAI_GPT_CONFIG]
@@ -1718,7 +1718,7 @@ Here is an example of the conversion process for a pre-trained Transformer-XL mo
 ```shell
 export TRANSFO_XL_CHECKPOINT_FOLDER_PATH=/path/to/transfo/xl/checkpoint
 
-pytorch_pretrained_bert convert_transfo_xl_checkpoint \
+pytorch_pretrained_bert transfo_xl \
   $TRANSFO_XL_CHECKPOINT_FOLDER_PATH \
   $PYTORCH_DUMP_OUTPUT \
   [TRANSFO_XL_CONFIG]
@@ -1731,12 +1731,28 @@ Here is an example of the conversion process for a pre-trained OpenAI's GPT-2 mo
 ```shell
 export GPT2_DIR=/path/to/gpt2/checkpoint
 
-pytorch_pretrained_bert convert_gpt2_checkpoint \
+pytorch_pretrained_bert gpt2 \
   $GPT2_DIR/model.ckpt \
   $PYTORCH_DUMP_OUTPUT \
   [GPT2_CONFIG]
 ```
 
+### XLNet
+
+Here is an example of the conversion process for a pre-trained XLNet model, fine-tuned on STS-B using the TensorFlow script:
+
+```shell
+export TRANSFO_XL_CHECKPOINT_PATH=/path/to/xlnet/checkpoint
+export TRANSFO_XL_CONFIG_PATH=/path/to/xlnet/config
+
+pytorch_pretrained_bert xlnet \
+  $TRANSFO_XL_CHECKPOINT_PATH \
+  $TRANSFO_XL_CONFIG_PATH \
+  $PYTORCH_DUMP_OUTPUT \
+  STS-B \
+```
+
+
 ## TPU
 
 TPU support and pretraining scripts
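The new `xlnet` sub-command maps onto the converter function wired up in `__main__.py` below. A programmatic equivalent of the STS-B example (an editorial sketch; the installed-package import path and the placeholder paths are assumptions):

```python
# Editorial sketch mirroring `pytorch_pretrained_bert xlnet ... STS-B`.
from pytorch_pretrained_bert.convert_xlnet_checkpoint_to_pytorch import (
    convert_xlnet_checkpoint_to_pytorch,
)

convert_xlnet_checkpoint_to_pytorch(
    "/path/to/xlnet/checkpoint",  # TF_CHECKPOINT
    "/path/to/xlnet/config",      # TF_CONFIG
    "/path/to/pytorch_dump",      # PYTORCH_DUMP_OUTPUT
    "STS-B",                      # finetuning task, as in the README example
)
```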
pytorch_pretrained_bert/__main__.py

@@ -1,20 +1,16 @@
 # coding: utf8
 def main():
     import sys
-    if (len(sys.argv) != 4 and len(sys.argv) != 5) or sys.argv[1] not in [
-        "convert_tf_checkpoint_to_pytorch",
-        "convert_openai_checkpoint",
-        "convert_transfo_xl_checkpoint",
-        "convert_gpt2_checkpoint",
-    ]:
+    if (len(sys.argv) < 4 or len(sys.argv) > 6) or sys.argv[1] not in ["bert", "gpt", "transfo_xl", "gpt2", "xlnet"]:
         print(
         "Should be used as one of: \n"
-        ">> `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
-        ">> `pytorch_pretrained_bert convert_openai_checkpoint OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
-        ">> `pytorch_pretrained_bert convert_transfo_xl_checkpoint TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
-        ">> `pytorch_pretrained_bert convert_gpt2_checkpoint TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]`")
+        ">> `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`, \n"
+        ">> `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`, \n"
+        ">> `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT_OR_DATASET PYTORCH_DUMP_OUTPUT [TF_CONFIG]` or \n"
+        ">> `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [GPT2_CONFIG]` or \n"
+        ">> `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
     else:
-        if sys.argv[1] == "convert_tf_checkpoint_to_pytorch":
+        if sys.argv[1] == "bert":
             try:
                 from .convert_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch
             except ImportError:
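Taken together, the new sub-commands imply the following argument counts; the helper below is an editorial restatement derived from the usage strings above, not code from this commit:

```python
# Argument ranges implied by the new usage strings (counts include the
# program name sys.argv[0] and the sub-command itself).
EXPECTED_ARGC = {
    "bert":       (5, 5),  # bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT
    "gpt":        (4, 5),  # gpt CHECKPOINT_FOLDER DUMP_OUTPUT [CONFIG]
    "transfo_xl": (4, 5),  # transfo_xl CKPT_OR_DATASET DUMP_OUTPUT [CONFIG]
    "gpt2":       (4, 5),  # gpt2 TF_CHECKPOINT DUMP_OUTPUT [CONFIG]
    "xlnet":      (5, 6),  # xlnet TF_CHECKPOINT TF_CONFIG DUMP_OUTPUT [TASK]
}

def argv_is_valid(argv):
    """Per-command check consistent with the branch-level checks in main()."""
    if len(argv) < 2 or argv[1] not in EXPECTED_ARGC:
        return False
    low, high = EXPECTED_ARGC[argv[1]]
    return low <= len(argv) <= high
```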
@@ -25,24 +21,28 @@ def main():
 
             if len(sys.argv) != 5:
                 # pylint: disable=line-too-long
-                print("Should be used as `pytorch_pretrained_bert convert_tf_checkpoint_to_pytorch TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
+                print("Should be used as `pytorch_pretrained_bert bert TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT`")
             else:
                 PYTORCH_DUMP_OUTPUT = sys.argv.pop()
                 TF_CONFIG = sys.argv.pop()
                 TF_CHECKPOINT = sys.argv.pop()
                 convert_tf_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "convert_openai_checkpoint":
+        elif sys.argv[1] == "gpt":
             from .convert_openai_checkpoint_to_pytorch import convert_openai_checkpoint_to_pytorch
-            OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
-            PYTORCH_DUMP_OUTPUT = sys.argv[3]
-            if len(sys.argv) == 5:
-                OPENAI_GPT_CONFIG = sys.argv[4]
+            if len(sys.argv) < 4 or len(sys.argv) > 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert gpt OPENAI_GPT_CHECKPOINT_FOLDER_PATH PYTORCH_DUMP_OUTPUT [OPENAI_GPT_CONFIG]`")
             else:
-                OPENAI_GPT_CONFIG = ""
-            convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
-                                                 OPENAI_GPT_CONFIG,
-                                                 PYTORCH_DUMP_OUTPUT)
-        elif sys.argv[1] == "convert_transfo_xl_checkpoint":
+                OPENAI_GPT_CHECKPOINT_FOLDER_PATH = sys.argv[2]
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                if len(sys.argv) == 5:
+                    OPENAI_GPT_CONFIG = sys.argv[4]
+                else:
+                    OPENAI_GPT_CONFIG = ""
+                convert_openai_checkpoint_to_pytorch(OPENAI_GPT_CHECKPOINT_FOLDER_PATH,
+                                                     OPENAI_GPT_CONFIG,
+                                                     PYTORCH_DUMP_OUTPUT)
+        elif sys.argv[1] == "transfo_xl":
             try:
                 from .convert_transfo_xl_checkpoint_to_pytorch import convert_transfo_xl_checkpoint_to_pytorch
             except ImportError:
@@ -50,20 +50,23 @@ def main():
                 "In that case, it requires TensorFlow to be installed. Please see "
                 "https://www.tensorflow.org/install/ for installation instructions.")
                 raise
-
-            if 'ckpt' in sys.argv[2].lower():
-                TF_CHECKPOINT = sys.argv[2]
-                TF_DATASET_FILE = ""
-            else:
-                TF_DATASET_FILE = sys.argv[2]
-                TF_CHECKPOINT = ""
-            PYTORCH_DUMP_OUTPUT = sys.argv[3]
-            if len(sys.argv) == 5:
-                TF_CONFIG = sys.argv[4]
-            else:
-                TF_CONFIG = ""
-            convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
-        else:
+            if len(sys.argv) < 4 or len(sys.argv) > 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert transfo_xl TF_CHECKPOINT/TF_DATASET_FILE PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+            else:
+                if 'ckpt' in sys.argv[2].lower():
+                    TF_CHECKPOINT = sys.argv[2]
+                    TF_DATASET_FILE = ""
+                else:
+                    TF_DATASET_FILE = sys.argv[2]
+                    TF_CHECKPOINT = ""
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                if len(sys.argv) == 5:
+                    TF_CONFIG = sys.argv[4]
+                else:
+                    TF_CONFIG = ""
+                convert_transfo_xl_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT, TF_DATASET_FILE)
+        elif sys.argv[1] == "gpt2":
             try:
                 from .convert_gpt2_checkpoint_to_pytorch import convert_gpt2_checkpoint_to_pytorch
             except ImportError:
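The `transfo_xl` branch keeps its earlier heuristic: the single positional input counts as a TensorFlow checkpoint if its path contains `ckpt`, otherwise it is treated as a dataset file. Restated as a standalone helper (editorial sketch, not code from the commit):

```python
# Editorial sketch of the checkpoint-vs-dataset heuristic used above.
def split_transfo_xl_input(path):
    """Return (tf_checkpoint, tf_dataset_file) the way the CLI branch does."""
    if "ckpt" in path.lower():
        return path, ""  # looks like a TF checkpoint
    return "", path      # otherwise treated as a dataset file
```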
@@ -72,12 +75,40 @@ def main():
                 "https://www.tensorflow.org/install/ for installation instructions.")
                 raise
-            TF_CHECKPOINT = sys.argv[2]
-            PYTORCH_DUMP_OUTPUT = sys.argv[3]
-            if len(sys.argv) == 5:
-                TF_CONFIG = sys.argv[4]
-            else:
-                TF_CONFIG = ""
-            convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+
+            if len(sys.argv) < 4 or len(sys.argv) > 5:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert gpt2 TF_CHECKPOINT PYTORCH_DUMP_OUTPUT [TF_CONFIG]`")
+            else:
+                TF_CHECKPOINT = sys.argv[2]
+                PYTORCH_DUMP_OUTPUT = sys.argv[3]
+                if len(sys.argv) == 5:
+                    TF_CONFIG = sys.argv[4]
+                else:
+                    TF_CONFIG = ""
+                convert_gpt2_checkpoint_to_pytorch(TF_CHECKPOINT, TF_CONFIG, PYTORCH_DUMP_OUTPUT)
+        else:
+            try:
+                from .convert_xlnet_checkpoint_to_pytorch import convert_xlnet_checkpoint_to_pytorch
+            except ImportError:
+                print("pytorch_pretrained_bert can only be used from the commandline to convert TensorFlow models in PyTorch, "
+                    "In that case, it requires TensorFlow to be installed. Please see "
+                    "https://www.tensorflow.org/install/ for installation instructions.")
+                raise
+
+            if len(sys.argv) < 5 or len(sys.argv) > 6:
+                # pylint: disable=line-too-long
+                print("Should be used as `pytorch_pretrained_bert xlnet TF_CHECKPOINT TF_CONFIG PYTORCH_DUMP_OUTPUT [FINETUNING_TASK_NAME]`")
+            else:
+                TF_CHECKPOINT = sys.argv[2]
+                TF_CONFIG = sys.argv[3]
+                PYTORCH_DUMP_OUTPUT = sys.argv[4]
+                if len(sys.argv) == 6:
+                    FINETUNING_TASK = sys.argv[5]
+
+                convert_xlnet_checkpoint_to_pytorch(TF_CHECKPOINT,
+                                                    TF_CONFIG,
+                                                    PYTORCH_DUMP_OUTPUT,
+                                                    FINETUNING_TASK)
 
 if __name__ == '__main__':
     main()
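One caveat in the new `xlnet` branch above: `FINETUNING_TASK` is only bound when six arguments are passed, so running the command without a task name would hit a `NameError` at the converter call. A defensive variant (editorial sketch, not part of the commit) defaults it first:

```python
import sys

# Editorial sketch: default the task name so the converter call always has a
# value, even when no finetuning task is given on the command line.
FINETUNING_TASK = None
if len(sys.argv) == 6:
    FINETUNING_TASK = sys.argv[5]
```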
pytorch_pretrained_bert/convert_xlnet_checkpoint_to_pytorch.py

@@ -70,7 +70,7 @@ if __name__ == "__main__":
                         required = True,
                         help = "The config json file corresponding to the pre-trained XLNet model. \n"
                             "This specifies the model architecture.")
-    parser.add_argument("--pytorch_dump_folder_path",finetuning_task
+    parser.add_argument("--pytorch_dump_folder_path",
                         default = None,
                         type = str,
                         required = True,
@@ -81,6 +81,6 @@ if __name__ == "__main__":
                         help = "Name of a task on which the XLNet TensorFloaw model was fine-tuned")
     args = parser.parse_args()
     convert_xlnet_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                         args.xlnet_config_file,
                                         args.pytorch_dump_folder_path,
                                         args.finetuning_task)
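The fix in the first hunk removes a stray `finetuning_task` token that had been pasted onto the `--pytorch_dump_folder_path` line; left in place, it would raise a `NameError` when the script runs, since no such name is defined at that point. A self-contained sketch of the corrected argument (the help text and the example value are assumptions, not taken from the file):

```python
import argparse

# Editorial sketch of the corrected argument definition.
parser = argparse.ArgumentParser()
parser.add_argument("--pytorch_dump_folder_path",
                    default=None,
                    type=str,
                    required=True,
                    help="Folder where the converted PyTorch model is written.")

args = parser.parse_args(["--pytorch_dump_folder_path", "/tmp/xlnet_pytorch"])
print(args.pytorch_dump_folder_path)
```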