fix megatron bert convert state dict naming (#15820)

commit 33cd4be576 (parent 9a2995ee39)
Author: Zhengqiang Yin
Date: 2022-04-18 23:34:36 +08:00, committed by GitHub


@@ -155,6 +155,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The simple map of names for "automated" rules.
     megatron_to_transformers = {
         "attention.dense": ".attention.output.dense.",
+        "self_attention.dense": ".attention.output.dense.",
         "mlp.dense_h_to_4h": ".intermediate.dense.",
         "mlp.dense_4h_to_h": ".output.dense.",
     }
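Context for the hunk above: newer Megatron-LM checkpoints name the attention block `self_attention` rather than `attention`, so both spellings must map onto the same Transformers key. Below is a minimal sketch of how such a lookup table is typically applied; the regex, the `bert.encoder.layer` prefix, and the `rename` helper are illustrative assumptions, not the converter's exact code:

```python
import re

megatron_to_transformers = {
    "attention.dense": ".attention.output.dense.",
    "self_attention.dense": ".attention.output.dense.",
    "mlp.dense_h_to_4h": ".intermediate.dense.",
    "mlp.dense_4h_to_h": ".output.dense.",
}

# Assumed key shape, e.g. "layers.3.self_attention.dense.weight".
layer_re = re.compile(r"layers\.(\d+)\.([a-z0-9_.]+)\.([a-z]+)")

def rename(megatron_key):
    """Translate one Megatron parameter name to a Transformers-style key."""
    m = layer_re.match(megatron_key)
    if m is None:
        raise ValueError(f"unexpected key: {megatron_key}")
    layer_idx, op_name, weight_or_bias = m.groups()
    # Hypothetical prefix; the real converter builds the layer name itself.
    return "bert.encoder.layer." + layer_idx + megatron_to_transformers[op_name] + weight_or_bias

# Old and new Megatron spellings now land on the same Transformers key.
assert rename("layers.0.attention.dense.weight") == rename("layers.0.self_attention.dense.weight")
```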
@@ -188,7 +189,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
             output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

         # Transpose the QKV matrix.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "weight":

             # Make sure the QKV pointer is nil.
             assert attention_qkv_weight is None, ""
@@ -198,7 +201,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
             attention_qkv_weight = out_val

         # Transpose the bias.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "bias":

             # Make sure we read the weight tensor.
             assert attention_qkv_weight is not None, ""
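The two widened `elif` branches cover the fused query/key/value projection under either op name. Below is a runnable sketch of the pairing logic the assertions guard, assuming a fused `[3 * hidden, hidden]` QKV layout; all names, shapes, and the `torch.chunk` split are hypothetical, and the converter's exact slicing may differ:

```python
import torch

QKV_OPS = ("attention.query_key_value", "self_attention.query_key_value")
hidden = 4

# Hypothetical stand-in for the checkpoint walk: (op_name, weight_or_bias, tensor).
fake_params = [
    ("self_attention.query_key_value", "weight", torch.randn(3 * hidden, hidden)),
    ("self_attention.query_key_value", "bias", torch.randn(3 * hidden)),
]

attention_qkv_weight = None
for op_name, weight_or_bias, val in fake_params:
    if op_name in QKV_OPS and weight_or_bias == "weight":
        # The fused projection is transposed and cached until its bias shows up,
        # hence "make sure the QKV pointer is nil" above.
        assert attention_qkv_weight is None
        attention_qkv_weight = val.transpose(0, 1)
    elif op_name in QKV_OPS and weight_or_bias == "bias":
        # The weight precedes the bias in the checkpoint, so the cache must be set.
        assert attention_qkv_weight is not None
        q_w, k_w, v_w = torch.chunk(attention_qkv_weight, 3, dim=1)
        q_b, k_b, v_b = torch.chunk(val, 3, dim=0)
        attention_qkv_weight = None
        print(q_w.shape, q_b.shape)  # torch.Size([4, 4]) torch.Size([4])
```

A tuple membership test like `op_name in QKV_OPS` would behave identically to the chained `or` the patch adds; the patch keeps the explicit comparisons, which confines the change to the two affected conditions.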