mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
adding gpt-2 large
This commit is contained in:
parent
e4515faf54
commit
aa05dc8935
@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
|
||||
if gpt2_config_file == "":
|
||||
config = GPT2Config()
|
||||
else:
|
||||
- config = GPT2Config(gpt2_config_file)
|
||||
+ config = GPT2Config.from_json_file(gpt2_config_file)
|
||||
model = GPT2Model(config)
|
||||
|
||||
# Load weights from numpy
|
||||
|
@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
|
||||
if openai_config_file == "":
|
||||
config = OpenAIGPTConfig()
|
||||
else:
|
||||
- config = OpenAIGPTConfig(openai_config_file)
|
||||
+ config = OpenAIGPTConfig.from_json_file(openai_config_file)
|
||||
model = OpenAIGPTModel(config)
|
||||
|
||||
# Load weights from numpy
|
||||
|
@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
|
||||
if transfo_xl_config_file == "":
|
||||
config = TransfoXLConfig()
|
||||
else:
|
||||
- config = TransfoXLConfig(transfo_xl_config_file)
|
||||
+ config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
|
||||
print("Building PyTorch model from configuration: {}".format(str(config)))
|
||||
model = TransfoXLLMHeadModel(config)
|
||||
|
||||
|
@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
|
||||
- "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
|
||||
+ "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
|
||||
+ "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
|
||||
GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
|
||||
- "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
|
||||
+ "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
|
||||
+ "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
|
||||
|
||||
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
||||
""" Load tf checkpoints in a pytorch model
|
||||
|
@ -45,11 +45,13 @@ PRETRAINED_VOCAB_FILES_MAP = {
|
||||
{
|
||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
|
||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
|
||||
+ 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
|
||||
},
|
||||
'merges_file':
|
||||
{
|
||||
'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
|
||||
'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
|
||||
+ 'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
|
||||
},
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user