Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
fix megatron bert convert state dict naming (#15820)
commit 33cd4be576 (parent 9a2995ee39)
@@ -155,6 +155,7 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
     # The simple map of names for "automated" rules.
     megatron_to_transformers = {
         "attention.dense": ".attention.output.dense.",
+        "self_attention.dense": ".attention.output.dense.",
         "mlp.dense_h_to_4h": ".intermediate.dense.",
         "mlp.dense_4h_to_h": ".output.dense.",
     }
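Context (not part of the commit): newer Megatron-LM releases name the attention block `self_attention` rather than `attention`, so their checkpoints carry `self_attention.*` keys that the old map never matched. The map above drives the converter's "automated" renames; a minimal sketch of how such a map is applied, with an illustrative layer prefix:

    # Sketch only, assuming the converter's simple splice rule: insert the
    # Transformers fragment between the layer prefix and "weight"/"bias".
    megatron_to_transformers = {
        "self_attention.dense": ".attention.output.dense.",
        # ... remaining entries as in the diff above ...
    }

    def rename(layer_name: str, op_name: str, weight_or_bias: str) -> str:
        return layer_name + megatron_to_transformers[op_name] + weight_or_bias

    # rename("bert.encoder.layer.0", "self_attention.dense", "weight")
    # -> "bert.encoder.layer.0.attention.output.dense.weight"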
@@ -188,7 +189,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
             output_state_dict[layer_name + "." + ln_name + "." + weight_or_bias] = val

         # Transpose the QKV matrix.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "weight":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "weight":

             # Make sure the QKV pointer is nil.
             assert attention_qkv_weight is None, ""
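The `weight` branch stashes the transposed QKV matrix in `attention_qkv_weight` and defers emitting Q, K and V until the matching bias arrives in the next hunk, hence the assert that the pointer is still nil. A minimal sketch of the shape bookkeeping for a fused QKV projection (illustrative only: real Megatron checkpoints interleave Q/K/V per attention head, which the converter reorders before splitting):

    import torch

    hidden = 1024                                # hypothetical hidden size
    fused_qkv = torch.randn(3 * hidden, hidden)  # fused [3h, h] projection
    q, k, v = torch.chunk(fused_qkv, 3, dim=0)   # naive split into three [h, h]
    assert q.shape == (hidden, hidden)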
@@ -198,7 +201,9 @@ def convert_megatron_checkpoint(args, input_state_dict, config):
             attention_qkv_weight = out_val

         # Transpose the bias.
-        elif op_name == "attention.query_key_value" and weight_or_bias == "bias":
+        elif (
+            op_name == "attention.query_key_value" or op_name == "self_attention.query_key_value"
+        ) and weight_or_bias == "bias":

             # Make sure we read the weight tensor.
             assert attention_qkv_weight is not None, ""
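Before this fix, checkpoints written with the newer `self_attention.query_key_value` name matched neither `elif`, so their fused QKV parameters were not converted. A quick sanity check one might run on the result (hypothetical helper; `output_state_dict` is assumed to be the dict produced by `convert_megatron_checkpoint`):

    def assert_fully_converted(output_state_dict):
        # No Megatron-style names should survive in the converted dict.
        leftovers = [k for k in output_state_dict if "query_key_value" in k]
        assert not leftovers, f"unconverted Megatron keys: {leftovers}"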