mirror of https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Fix convert for newer megatron-lm bert model (#14082)
* Fix convert for newer megatron-lm models
* Save megatron-bert config in a proper way
* Fix code style
parent 623b4f7c63
commit 768e6c1449
@@ -33,13 +33,14 @@
 #

 import argparse
 import json
 import os
 import re
 import zipfile

 import torch

+from transformers import MegatronBertConfig


 ####################################################################################################
@@ -64,13 +65,62 @@ def recursive_print(name, val, spaces=0):
         print(msg, ":", val)


+def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
+    # Permutes layout of param tensor to [num_splits * num_heads * hidden_size, :]
+    # for compatibility with later versions of NVIDIA Megatron-LM.
+    # The inverse operation is performed inside Megatron-LM to read checkpoints:
+    # https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209
+    # If param is the weight tensor of the self-attention block, the returned tensor
+    # will have to be transposed one more time to be read by HuggingFace BERT.
+    input_shape = param.size()
+    if checkpoint_version == 1.0:
+        # version 1.0 stores [num_heads * hidden_size * num_splits, :]
+        saved_shape = (num_heads, hidden_size, num_splits) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 2)
+        param = param.transpose(1, 2).contiguous()
+    elif checkpoint_version >= 2.0:
+        # other versions store [num_heads * num_splits * hidden_size, :]
+        saved_shape = (num_heads, num_splits, hidden_size) + input_shape[1:]
+        param = param.view(*saved_shape)
+        param = param.transpose(0, 1).contiguous()
+    param = param.view(*input_shape)
+    return param


 ####################################################################################################


-def convert_megatron_checkpoint(args, input_state_dict):
+def convert_megatron_checkpoint(args, input_state_dict, config):
     # The converted output model.
     output_state_dict = {}

+    # old versions did not store training args
+    ds_args = input_state_dict.get("args", None)
+    if ds_args is not None:
+        # do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
+        # from pprint import pprint
+        # pprint(vars(ds_args))
+
+        config.tokenizer_type = ds_args.tokenizer_type
+        config.vocab_size = ds_args.padded_vocab_size
+        config.max_position_embeddings = ds_args.max_position_embeddings
+        config.hidden_size = ds_args.hidden_size
+        config.num_hidden_layers = ds_args.num_layers
+        config.num_attention_heads = ds_args.num_attention_heads
+        config.intermediate_size = getattr(ds_args, "ffn_hidden_size", 4 * ds_args.hidden_size)
+        # pprint(config)
+
+    # The number of heads.
+    heads = config.num_attention_heads
+    # The hidden_size per head.
+    hidden_size_per_head = config.hidden_size // heads
+    # Megatron-LM checkpoint version
+    if "checkpoint_version" in input_state_dict.keys():
+        checkpoint_version = input_state_dict["checkpoint_version"]
+    else:
+        checkpoint_version = 0.0

     # The model.
     model = input_state_dict["model"]
     # The language model.
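When it is present, input_state_dict["args"] is the argparse.Namespace that Megatron-LM pickled at training time, which is why the block above reads attributes off it rather than dictionary keys (and why a plain .get() call would fail there). A toy stand-in, with every field value invented, exercises the same override in isolation:

from argparse import Namespace

from transformers import MegatronBertConfig

# Invented values standing in for a real checkpoint's saved training args.
ds_args = Namespace(
    tokenizer_type="BertWordPieceLowerCase",
    padded_vocab_size=30592,
    max_position_embeddings=512,
    hidden_size=1024,
    num_layers=24,
    num_attention_heads=16,
    ffn_hidden_size=4096,
)

config = MegatronBertConfig()
config.vocab_size = ds_args.padded_vocab_size
config.max_position_embeddings = ds_args.max_position_embeddings
config.hidden_size = ds_args.hidden_size
config.num_hidden_layers = ds_args.num_layers
config.num_attention_heads = ds_args.num_attention_heads
# getattr mirrors the fallback for checkpoints that predate ffn_hidden_size.
config.intermediate_size = getattr(ds_args, "ffn_hidden_size", 4 * ds_args.hidden_size)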
@@ -80,13 +130,14 @@ def convert_megatron_checkpoint(args, input_state_dict):

     # The word embeddings.
     word_embeddings = embeddings["word_embeddings"]["weight"]
+    # Truncate the embedding table to vocab_size rows.
+    word_embeddings = word_embeddings[: config.vocab_size, :]
     # Store the word embeddings.
     output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings

     # The position embeddings.
     pos_embeddings = embeddings["position_embeddings"]["weight"]
-    # Trained for 512 x 1024.
-    assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
+    assert pos_embeddings.size(0) == config.max_position_embeddings and pos_embeddings.size(1) == config.hidden_size
     # Store the position embeddings.
     output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings
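The truncation matters because Megatron-LM pads the embedding table (the padded_vocab_size above) so it divides evenly across tensor-parallel ranks; when the target config asks for fewer rows, the padding is simply dropped. A toy sketch with invented sizes:

import torch

vocab_size = 30522    # rows the target config asks for
padded_rows = 30592   # rows after Megatron's tensor-parallel padding
word_embeddings = torch.randn(padded_rows, 1024)

word_embeddings = word_embeddings[: vocab_size, :]
assert word_embeddings.size(0) == vocab_size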
@@ -96,7 +147,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
     output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings

     # The transformer.
-    transformer = lm["transformer"]
+    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]

     # The regex to extract layer names.
     layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
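The new fallback covers a rename in later Megatron-LM releases, where the language model stores this block under "encoder" instead of "transformer". A dummy-dict sketch of the lookup:

# Dummy stand-ins for the two checkpoint layouts.
old_lm = {"transformer": {"final_layernorm.weight": "w"}}
new_lm = {"encoder": {"final_layernorm.weight": "w"}}

for lm in (old_lm, new_lm):
    transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]
    assert "final_layernorm.weight" in transformer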
@@ -142,8 +193,9 @@ def convert_megatron_checkpoint(args, input_state_dict):
             # Make sure the QKV pointer is nil.
             assert attention_qkv_weight is None, ""

+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Store the tensor as we need the bias as well to interleave QKV and biases.
-            attention_qkv_weight = val
+            attention_qkv_weight = out_val

         # Transpose the bias.
         elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
@@ -152,14 +204,15 @@ def convert_megatron_checkpoint(args, input_state_dict):
             assert attention_qkv_weight is not None, ""

             # Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
-            q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
-            k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
-            v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]
+            q = attention_qkv_weight[0 * config.hidden_size : 1 * config.hidden_size, :]
+            k = attention_qkv_weight[1 * config.hidden_size : 2 * config.hidden_size, :]
+            v = attention_qkv_weight[2 * config.hidden_size : 3 * config.hidden_size, :]

+            out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
             # Split the bias.
-            q_bias = val[0 * 1024 : 1 * 1024]
-            k_bias = val[1 * 1024 : 2 * 1024]
-            v_bias = val[2 * 1024 : 3 * 1024]
+            q_bias = out_val[0 * config.hidden_size : 1 * config.hidden_size]
+            k_bias = out_val[1 * config.hidden_size : 2 * config.hidden_size]
+            v_bias = out_val[2 * config.hidden_size : 3 * config.hidden_size]

             # Store.
             output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
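The plain slicing works because fix_query_key_value_ordering has already regrouped the rows into contiguous Q, K and V blocks of config.hidden_size rows each. A toy sketch of the split, with an invented hidden size:

import torch

hidden_size = 8
# Stand-in for a reordered QKV weight: rows [0, h) are Q, [h, 2h) are K, [2h, 3h) are V.
attention_qkv_weight = torch.randn(3 * hidden_size, hidden_size)

q = attention_qkv_weight[0 * hidden_size : 1 * hidden_size, :]
k = attention_qkv_weight[1 * hidden_size : 2 * hidden_size, :]
v = attention_qkv_weight[2 * hidden_size : 3 * hidden_size, :]
assert q.shape == k.shape == v.shape == (hidden_size, hidden_size)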
@@ -182,24 +235,6 @@ def convert_megatron_checkpoint(args, input_state_dict):
     output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
     output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]

-    # The config.
-    output_config = {
-        "vocab_size": word_embeddings.size(0),
-        "hidden_size": 1024,
-        "num_hidden_layers": 24,
-        "num_attention_heads": 16,
-        "hidden_act": "gelu_new",
-        "intermediate_size": 4096,
-        "hidden_dropout_prob": 0.1,
-        "attention_probs_dropout_prob": 0.1,
-        "max_position_embeddings": 512,
-        "type_vocab_size": 2,
-        "initializer_range": 0.2,
-        "layer_norm_eps": 1e-12,
-        "position_embedding_type": "absolute",
-        "use_cache": False,
-    }
-
     # The pooler.
     pooler = lm["pooler"]
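The hard-coded dict can be dropped because MegatronBertConfig carries defaults for the same 345m architecture, so building the config object loses nothing. A quick check, assuming the library defaults still describe the 345m card:

from transformers import MegatronBertConfig

config = MegatronBertConfig()
assert config.hidden_size == 1024
assert config.num_hidden_layers == 24
assert config.num_attention_heads == 16
assert config.intermediate_size == 4096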
@@ -230,7 +265,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
     output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]

     # It should be done!
-    return output_state_dict, output_config
+    return output_state_dict


 ####################################################################################################
@@ -241,30 +276,44 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--print-checkpoint-structure", action="store_true")
     parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
+    parser.add_argument(
+        "--config_file",
+        default="",
+        type=str,
+        help="An optional config json file describing the pre-trained model.",
+    )
     args = parser.parse_args()

     # Extract the basename.
     basename = os.path.dirname(args.path_to_checkpoint)

     # Load the model.
+    # the .zip is very optional, let's keep it for backward compatibility
     print(f'Extracting PyTorch state dictionary from "{args.path_to_checkpoint}"')
-    with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
-        with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
-            input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    if args.path_to_checkpoint.endswith(".zip"):
+        with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
+            with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
+                input_state_dict = torch.load(pytorch_dict, map_location="cpu")
+    else:
+        input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
+
+    if args.config_file == "":
+        # Default config of megatron-bert 345m
+        config = MegatronBertConfig()
+    else:
+        config = MegatronBertConfig.from_json_file(args.config_file)

     # Convert.
     print("Converting")
-    output_state_dict, output_config = convert_megatron_checkpoint(args, input_state_dict)
+    output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)

     # Print the structure of converted state dict.
     if args.print_checkpoint_structure:
         recursive_print(None, output_state_dict)

     # Store the config to file.
-    output_config_file = os.path.join(basename, "config.json")
-    print(f'Saving config to "{output_config_file}"')
-    with open(output_config_file, "w") as f:
-        json.dump(output_config, f)
+    print("Saving config")
+    config.save_pretrained(basename)

     # Store the state_dict to file.
     output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")
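Once main() has run, the checkpoint's parent directory holds config.json and pytorch_model.bin, so the converted model should load through the standard API. A sketch, with a placeholder path:

from transformers import MegatronBertForPreTraining

# Placeholder: the directory the converter wrote config.json and pytorch_model.bin into.
model = MegatronBertForPreTraining.from_pretrained("path/to/checkpoint_dir")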