Merge pull request #1064 from huggingface/gpt-2-large

Adding the GPT-2 large (774M parameters) model
Commit 07681b6b58 by Thomas Wolf, 2019-08-21 03:05:56 +02:00 (committed by GitHub)
5 changed files with 10 additions and 5 deletions


@@ -35,7 +35,7 @@ def convert_gpt2_checkpoint_to_pytorch(gpt2_checkpoint_path, gpt2_config_file, p
     if gpt2_config_file == "":
         config = GPT2Config()
     else:
-        config = GPT2Config(gpt2_config_file)
+        config = GPT2Config.from_json_file(gpt2_config_file)
     model = GPT2Model(config)
     # Load weights from numpy
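
The change above replaces the positional call GPT2Config(gpt2_config_file) with the explicit factory GPT2Config.from_json_file(gpt2_config_file): passing a path positionally binds it to the constructor's first parameter instead of unambiguously parsing the JSON file. The same one-line fix is applied to the OpenAI GPT and Transformer-XL conversion scripts in the two hunks below. A minimal sketch of the corrected pattern, assuming the pytorch_transformers package layout of this era; the config path is hypothetical:

    # Sketch only: "my-gpt2-config.json" is a hypothetical local file.
    from pytorch_transformers import GPT2Config, GPT2Model

    # Explicitly parse the JSON config instead of passing the path
    # as the first positional constructor argument.
    config = GPT2Config.from_json_file("my-gpt2-config.json")
    model = GPT2Model(config)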


@@ -35,7 +35,7 @@ def convert_openai_checkpoint_to_pytorch(openai_checkpoint_folder_path, openai_c
     if openai_config_file == "":
         config = OpenAIGPTConfig()
     else:
-        config = OpenAIGPTConfig(openai_config_file)
+        config = OpenAIGPTConfig.from_json_file(openai_config_file)
     model = OpenAIGPTModel(config)
    # Load weights from numpy


@@ -75,7 +75,7 @@ def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
     if transfo_xl_config_file == "":
         config = TransfoXLConfig()
     else:
-        config = TransfoXLConfig(transfo_xl_config_file)
+        config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
     print("Building PyTorch model from configuration: {}".format(str(config)))
     model = TransfoXLLMHeadModel(config)


@@ -38,9 +38,11 @@ from .modeling_bert import BertLayerNorm as LayerNorm
 logger = logging.getLogger(__name__)
 
 GPT2_PRETRAINED_MODEL_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-pytorch_model.bin",
-                                     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin"}
+                                     "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-pytorch_model.bin",
+                                     "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-pytorch_model.bin"}
 GPT2_PRETRAINED_CONFIG_ARCHIVE_MAP = {"gpt2": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json",
-                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json"}
+                                      "gpt2-medium": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-config.json",
+                                      "gpt2-large": "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-config.json"}
 
 def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
     """ Load tf checkpoints in a pytorch model


@@ -45,17 +45,20 @@ PRETRAINED_VOCAB_FILES_MAP = {
     {
         'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json",
         'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-vocab.json",
+        'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-vocab.json",
     },
     'merges_file':
     {
         'gpt2': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt",
         'gpt2-medium': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-merges.txt",
+        'gpt2-large': "https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-merges.txt",
     },
 }
 
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
     'gpt2': 1024,
     'gpt2-medium': 1024,
+    'gpt2-large': 1024,
 }
 
 @lru_cache()
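
The tokenizer maps get matching "gpt2-large" entries for the vocabulary and BPE merges, and the positional-embeddings size table records the same 1024-token context window as the smaller checkpoints (all GPT-2 sizes share it). A sketch of the tokenizer side, again assuming the pytorch_transformers API of this era:

    from pytorch_transformers import GPT2Tokenizer

    # Resolves gpt2-large-vocab.json and gpt2-large-merges.txt
    # from the maps above.
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")

    ids = tokenizer.encode("Adding the 774M parameter model.")
    text = tokenizer.decode(ids)
    print(tokenizer.max_len)  # 1024, from PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES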