Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
[megatron_gpt2] dynamic gelu, add tokenizer, save config (#13928)
* [megatron_gpt2] dynamic gelu, add tokenizer, save config
* cleanup
* Update src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py
* apply suggestions

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent 919a964b8f
commit bfd8176636
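What this buys in practice: after running the converter, the output folder can be loaded straight through the usual from_pretrained machinery, because config.json now records the tokenizer class, the architectures list, and the activation function detected from the Megatron training args. A minimal sketch, assuming a hypothetical output directory ./megatron_gpt2_converted that already contains the converted config.json, pytorch_model.bin and tokenizer files:

from transformers import AutoTokenizer, GPT2LMHeadModel

# Hypothetical path: the directory the conversion script wrote its output to.
output_dir = "./megatron_gpt2_converted"

# The tokenizer files were saved by the converter, and config.json carries
# tokenizer_class, so AutoTokenizer resolves the right class on its own.
tokenizer = AutoTokenizer.from_pretrained(output_dir)

# activation_function in config.json now matches the gelu variant the
# Megatron model was actually trained with.
model = GPT2LMHeadModel.from_pretrained(output_dir)

inputs = tokenizer("Megatron-LM checkpoint, converted to GPT-2:", return_tensors="pt")
outputs = model.generate(**inputs, max_length=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))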
src/transformers/models/megatron_gpt2/convert_megatron_gpt2_checkpoint.py

@@ -17,14 +17,13 @@
 ####################################################################################################

 import argparse
 import json
 import os
 import re
 import zipfile

 import torch

-from transformers import GPT2Config
+from transformers import AutoTokenizer, GPT2Config


 ####################################################################################################
@@ -81,19 +80,19 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     output_state_dict = {}

     # old versions did not store training args
-    if "args" in input_state_dict:
+    ds_args = input_state_dict.get("args", None)
+    if ds_args is not None:
         # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
-        train_args = input_state_dict["args"]
         # from pprint import pprint
-        # pprint(vars(train_args))
+        # pprint(vars(ds_args))

-        config.vocab_size = train_args.padded_vocab_size
-        config.n_positions = train_args.max_position_embeddings
-        config.n_ctx = train_args.seq_length
-        config.n_embd = train_args.hidden_size
-        config.n_layer = train_args.num_layers
-        config.n_head = train_args.num_attention_heads
-        config.n_inner = train_args.ffn_hidden_size
+        config.vocab_size = ds_args.padded_vocab_size
+        config.n_positions = ds_args.max_position_embeddings
+        config.n_ctx = ds_args.seq_length
+        config.n_embd = ds_args.hidden_size
+        config.n_layer = ds_args.num_layers
+        config.n_head = ds_args.num_attention_heads
+        config.n_inner = ds_args.ffn_hidden_size
         # pprint(config)

     # The number of heads.
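For context, the "args" entry read above is the argparse Namespace that Megatron-LM pickles into its checkpoints; old checkpoints simply do not have it. A small sketch of inspecting those fields before conversion (the checkpoint path is hypothetical, and depending on the Megatron-LM version, unpickling may require Megatron itself to be importable):

import torch

# Hypothetical path to an unpacked Megatron-LM GPT-2 checkpoint.
checkpoint_path = "release/mp_rank_00/model_optim_rng.pt"

state_dict = torch.load(checkpoint_path, map_location="cpu")
ds_args = state_dict.get("args", None)  # argparse.Namespace, or None for old checkpoints

if ds_args is None:
    print("No training args stored; a config file has to be supplied to the converter.")
else:
    # The fields the converter maps onto GPT2Config above.
    for name in (
        "padded_vocab_size",
        "max_position_embeddings",
        "seq_length",
        "hidden_size",
        "num_layers",
        "num_attention_heads",
        "ffn_hidden_size",
    ):
        print(f"{name} = {getattr(ds_args, name, '<missing>')}")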
@@ -255,8 +254,22 @@ def main():
     else:
         input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")

+    ds_args = input_state_dict.get("args", None)
+
     # Read the config, or default to the model released by NVIDIA.
     if args.config_file == "":
+
+        if ds_args is not None:
+            if ds_args.bias_gelu_fusion:
+                activation_function = "gelu_fast"
+            elif ds_args.openai_gelu:
+                activation_function = "gelu_new"
+            else:
+                activation_function = "gelu"
+        else:
+            # in the very early days this used to be "gelu_new"
+            activation_function = "gelu_new"
+
         # Spell out all parameters in case the defaults change.
         config = GPT2Config(
             vocab_size=50257,
@@ -266,7 +279,7 @@ def main():
             n_layer=24,
             n_head=16,
             n_inner=4096,
-            activation_function="gelu",  # used to be "gelu_new" in earlier versions
+            activation_function=activation_function,
             resid_pdrop=0.1,
             embd_pdrop=0.1,
             attn_pdrop=0.1,
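The three strings selected above are keys into transformers' activation registry, which is what GPT-2's MLP uses to resolve config.activation_function; roughly, bias_gelu_fusion corresponds to the fast tanh approximation, openai_gelu to GPT-2's original "new" gelu, and the Megatron default to the exact gelu. A quick sketch comparing them via transformers.activations.ACT2FN; the numeric differences are small but nonzero:

import torch
from transformers.activations import ACT2FN

x = torch.linspace(-3.0, 3.0, steps=7)
for name in ("gelu", "gelu_new", "gelu_fast"):
    # Each registry entry is a callable; GPT-2's MLP applies the one named
    # by config.activation_function.
    print(name, ACT2FN[name](x))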
@@ -285,6 +298,8 @@ def main():
     else:
         config = GPT2Config.from_json_file(args.config_file)

+    config.architectures = ["GPT2LMHeadModel"]
+
     # Convert.
     print("Converting")
     output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
@@ -293,14 +308,30 @@ def main():
     if args.print_checkpoint_structure:
         recursive_print(None, output_state_dict)

+    # Add tokenizer class info to config
+    # see https://github.com/huggingface/transformers/issues/13906)
+    if ds_args is not None:
+        tokenizer_type = ds_args.tokenizer_type
+        if tokenizer_type == "GPT2BPETokenizer":
+            tokenizer_model_name = "gpt2"
+        elif tokenizer_type == "PretrainedFromHF":
+            tokenizer_model_name = ds_args.tokenizer_name_or_path
+        else:
+            raise ValueError(f"Unrecognized tokenizer_type {tokenizer_type}")
+    else:
+        tokenizer_model_name = "gpt2"
+
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model_name)
+    tokenizer_class = type(tokenizer).__name__
+    config.tokenizer_class = tokenizer_class
+
     # Store the config to file.
-    output_config_file = os.path.join(basename, "config.json")
-    output_config = config.to_dict()
-    output_config["architectures"] = ["GPT2LMHeadModel"]
-    output_config["model_type"] = "gpt2"
-    print(f'Saving config to "{output_config_file}"')
-    with open(output_config_file, "w") as f:
-        json.dump(output_config, f)
+    print("Saving config")
+    config.save_pretrained(basename)
+
+    # Save tokenizer based on args
+    print(f"Adding {tokenizer_class} tokenizer files")
+    tokenizer.save_pretrained(basename)

     # Store the state_dict to file.
     output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
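The hand-rolled json.dump that this commit removes is covered by config.save_pretrained: every attribute set on the config, including the tokenizer_class and architectures fields added here, ends up in the written config.json, along with model_type, which GPT2Config defines as "gpt2". A small self-contained sketch (the tokenizer class name is just an example of what AutoTokenizer typically resolves for "gpt2"):

import json
import tempfile

from transformers import GPT2Config

config = GPT2Config(activation_function="gelu_fast")
config.architectures = ["GPT2LMHeadModel"]
config.tokenizer_class = "GPT2TokenizerFast"  # example value; see lead-in

with tempfile.TemporaryDirectory() as out_dir:
    config.save_pretrained(out_dir)  # writes <out_dir>/config.json
    with open(f"{out_dir}/config.json") as f:
        saved = json.load(f)

# The fields the removed code wrote by hand are all present in the saved file.
print(saved["model_type"], saved["architectures"], saved["tokenizer_class"])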