Fix convert for newer megatron-lm bert model (#14082)

* Fix convert for newer megatron-lm models

* Save megatron-bert config in a proper way

* Fix code style
This commit is contained in:
yoquankara 2022-01-09 04:33:55 +09:00 committed by GitHub
parent 623b4f7c63
commit 768e6c1449
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -33,13 +33,14 @@
#
import argparse
import json
import os
import re
import zipfile
import torch
from transformers import MegatronBertConfig
####################################################################################################
@ -64,13 +65,62 @@ def recursive_print(name, val, spaces=0):
print(msg, ":", val)
def fix_query_key_value_ordering(param, checkpoint_version, num_splits, num_heads, hidden_size):
    """Reorder a fused parameter into the [num_splits * num_heads * hidden_size, :] layout.

    This layout is the one used by later versions of NVIDIA Megatron-LM; the
    inverse permutation is applied inside Megatron-LM when it reads checkpoints:
    https://github.com/NVIDIA/Megatron-LM/blob/v2.4/megatron/checkpointing.py#L209

    When ``param`` is the weight of the self-attention block, the caller still
    has to transpose the result once more before HuggingFace BERT can use it.

    Args:
        param: tensor whose first dimension is the fused
            ``num_splits * num_heads * hidden_size`` axis; any trailing
            dimensions are carried along untouched.
        checkpoint_version: Megatron-LM checkpoint version number.
        num_splits: number of fused projections along the first axis.
        num_heads: number of attention heads.
        hidden_size: size of the per-head slice of the first axis.

    Returns:
        A tensor with the same shape as ``param``. For version 1.0 and for
        versions >= 2.0 the rows are permuted; for every other version the
        data order is left as it is.
    """
    original_shape = param.size()
    tail_sizes = original_shape[1:]
    # After the first axis is split in three, the trailing axes sit at 3, 4, ...
    tail_axes = tuple(range(3, param.dim() + 2))

    if checkpoint_version == 1.0:
        # version 1.0 stores [num_heads * hidden_size * num_splits, :]
        stored = param.view(num_heads, hidden_size, num_splits, *tail_sizes)
        # (heads, hidden, splits) -> (splits, heads, hidden)
        reordered = stored.permute(2, 0, 1, *tail_axes).contiguous()
        return reordered.view(*original_shape)

    if checkpoint_version >= 2.0:
        # other versions store [num_heads * num_splits * hidden_size, :]
        stored = param.view(num_heads, num_splits, hidden_size, *tail_sizes)
        # (heads, splits, hidden) -> (splits, heads, hidden)
        reordered = stored.permute(1, 0, 2, *tail_axes).contiguous()
        return reordered.view(*original_shape)

    # Any other version: no permutation is applied.
    return param.view(*original_shape)
####################################################################################################
def convert_megatron_checkpoint(args, input_state_dict):
def convert_megatron_checkpoint(args, input_state_dict, config):
# The converted output model.
output_state_dict = {}
# old versions did not store training args
ds_args = input_state_dict.get("args", None)
if ds_args is not None:
# do not make the user write a config file when the exact dimensions/sizes are already in the checkpoint
# from pprint import pprint
# pprint(vars(ds_args))
config.tokenizer_type = ds_args.tokenizer_type
config.vocab_size = ds_args.padded_vocab_size
config.max_position_embeddings = ds_args.max_position_embeddings
config.hidden_size = ds_args.hidden_size
config.num_hidden_layers = ds_args.num_layers
config.num_attention_heads = ds_args.num_attention_heads
config.intermediate_size = ds_args.get("ffn_hidden_size", 4 * ds_args.hidden_size)
# pprint(config)
# The number of heads.
heads = config.num_attention_heads
# The hidden_size per head.
hidden_size_per_head = config.hidden_size // heads
# Megatron-LM checkpoint version
if "checkpoint_version" in input_state_dict.keys():
checkpoint_version = input_state_dict["checkpoint_version"]
else:
checkpoint_version = 0.0
# The model.
model = input_state_dict["model"]
# The language model.
@ -80,13 +130,14 @@ def convert_megatron_checkpoint(args, input_state_dict):
# The word embeddings.
word_embeddings = embeddings["word_embeddings"]["weight"]
# Truncate the embedding table to vocab_size rows.
word_embeddings = word_embeddings[: config.vocab_size, :]
# Store the word embeddings.
output_state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings
# The position embeddings.
pos_embeddings = embeddings["position_embeddings"]["weight"]
# Trained for 512 x 1024.
assert pos_embeddings.size(0) == 512 and pos_embeddings.size(1) == 1024
assert pos_embeddings.size(0) == config.max_position_embeddings and pos_embeddings.size(1) == config.hidden_size
# Store the position embeddings.
output_state_dict["bert.embeddings.position_embeddings.weight"] = pos_embeddings
@ -96,7 +147,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
output_state_dict["bert.embeddings.token_type_embeddings.weight"] = tokentype_embeddings
# The transformer.
transformer = lm["transformer"]
transformer = lm["transformer"] if "transformer" in lm.keys() else lm["encoder"]
# The regex to extract layer names.
layer_re = re.compile("layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")
@ -142,8 +193,9 @@ def convert_megatron_checkpoint(args, input_state_dict):
# Make sure the QKV pointer is nil.
assert attention_qkv_weight is None, ""
out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
# Store the tensor as we need the bias as well to interleave QKV and biases.
attention_qkv_weight = val
attention_qkv_weight = out_val
# Transpose the bias.
elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
@ -152,14 +204,15 @@ def convert_megatron_checkpoint(args, input_state_dict):
assert attention_qkv_weight is not None, ""
# Split the QKV matrix into Q, K and V. Megatron stores Q,K,V interleaved.
q = attention_qkv_weight[0 * 1024 : 1 * 1024, :]
k = attention_qkv_weight[1 * 1024 : 2 * 1024, :]
v = attention_qkv_weight[2 * 1024 : 3 * 1024, :]
q = attention_qkv_weight[0 * config.hidden_size : 1 * config.hidden_size, :]
k = attention_qkv_weight[1 * config.hidden_size : 2 * config.hidden_size, :]
v = attention_qkv_weight[2 * config.hidden_size : 3 * config.hidden_size, :]
out_val = fix_query_key_value_ordering(val, checkpoint_version, 3, heads, hidden_size_per_head)
# Split the bias.
q_bias = val[0 * 1024 : 1 * 1024]
k_bias = val[1 * 1024 : 2 * 1024]
v_bias = val[2 * 1024 : 3 * 1024]
q_bias = out_val[0 * config.hidden_size : 1 * config.hidden_size]
k_bias = out_val[1 * config.hidden_size : 2 * config.hidden_size]
v_bias = out_val[2 * config.hidden_size : 3 * config.hidden_size]
# Store.
output_state_dict[f"{layer_name}.attention.self.query.weight"] = q
@ -182,24 +235,6 @@ def convert_megatron_checkpoint(args, input_state_dict):
output_state_dict["bert.encoder.ln.weight"] = transformer["final_layernorm.weight"]
output_state_dict["bert.encoder.ln.bias"] = transformer["final_layernorm.bias"]
# The config.
output_config = {
"vocab_size": word_embeddings.size(0),
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"hidden_act": "gelu_new",
"intermediate_size": 4096,
"hidden_dropout_prob": 0.1,
"attention_probs_dropout_prob": 0.1,
"max_position_embeddings": 512,
"type_vocab_size": 2,
"initializer_range": 0.2,
"layer_norm_eps": 1e-12,
"position_embedding_type": "absolute",
"use_cache": False,
}
# The pooler.
pooler = lm["pooler"]
@ -230,7 +265,7 @@ def convert_megatron_checkpoint(args, input_state_dict):
output_state_dict["cls.seq_relationship.bias"] = binary_head["bias"]
# It should be done!
return output_state_dict, output_config
return output_state_dict
####################################################################################################
@ -241,30 +276,44 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument("--print-checkpoint-structure", action="store_true")
parser.add_argument("path_to_checkpoint", type=str, help="Path to the ZIP file containing the checkpoint")
parser.add_argument(
"--config_file",
default="",
type=str,
help="An optional config json file describing the pre-trained model.",
)
args = parser.parse_args()
# Extract the basename.
basename = os.path.dirname(args.path_to_checkpoint)
# Load the model.
# the .zip is very optional, let's keep it for backward compatibility
print(f'Extracting PyTorch state dictionary from "{args.path_to_checkpoint}"')
with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
input_state_dict = torch.load(pytorch_dict, map_location="cpu")
if args.path_to_checkpoint.endswith(".zip"):
with zipfile.ZipFile(args.path_to_checkpoint, "r") as checkpoint:
with checkpoint.open("release/mp_rank_00/model_optim_rng.pt") as pytorch_dict:
input_state_dict = torch.load(pytorch_dict, map_location="cpu")
else:
input_state_dict = torch.load(args.path_to_checkpoint, map_location="cpu")
if args.config_file == "":
# Default config of megatron-bert 345m
config = MegatronBertConfig()
else:
config = MegatronBertConfig.from_json_file(args.config_file)
# Convert.
print("Converting")
output_state_dict, output_config = convert_megatron_checkpoint(args, input_state_dict)
output_state_dict = convert_megatron_checkpoint(args, input_state_dict, config)
# Print the structure of converted state dict.
if args.print_checkpoint_structure:
recursive_print(None, output_state_dict)
# Store the config to file.
output_config_file = os.path.join(basename, "config.json")
print(f'Saving config to "{output_config_file}"')
with open(output_config_file, "w") as f:
json.dump(output_config, f)
print("Saving config")
config.save_pretrained(basename)
# Store the state_dict to file.
output_checkpoint_file = os.path.join(basename, "pytorch_model.bin")