mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
fix bert layernorm
This commit is contained in:
parent
d3a3a0353c
commit
dcddf498c8
@ -62,6 +62,19 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
|
||||
"https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.")
|
||||
raise e
|
||||
|
||||
# Adapt state dict - TODO remove this and update the AWS weights files instead
|
||||
for key in pt_state_dict.keys():
|
||||
new_key = None
|
||||
if 'gamma' in key:
|
||||
new_key = key.replace('gamma', 'weight')
|
||||
if 'beta' in key:
|
||||
new_key = key.replace('beta', 'bias')
|
||||
if new_key:
|
||||
old_keys.append(key)
|
||||
new_keys.append(new_key)
|
||||
for old_key, new_key in zip(old_keys, new_keys):
|
||||
pt_state_dict[new_key] = pt_state_dict.pop(old_key)
|
||||
|
||||
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
|
||||
|
||||
weight_value_tuples = []
|
||||
|
Loading…
Reference in New Issue
Block a user