From dcddf498c8b34cfeb99fb563e771877a1bc7ded5 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Thu, 12 Sep 2019 16:46:32 +0200 Subject: [PATCH] fix bert layernorm --- pytorch_transformers/modeling_tf_pytorch_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/pytorch_transformers/modeling_tf_pytorch_utils.py b/pytorch_transformers/modeling_tf_pytorch_utils.py index 67845969449..d979c0e1a4a 100644 --- a/pytorch_transformers/modeling_tf_pytorch_utils.py +++ b/pytorch_transformers/modeling_tf_pytorch_utils.py @@ -62,6 +62,19 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None "https://pytorch.org/ and https://www.tensorflow.org/install/ for installation instructions.") raise e + # Adapt state dict - TODO remove this and update the AWS weights files instead + for key in pt_state_dict.keys(): + new_key = None + if 'gamma' in key: + new_key = key.replace('gamma', 'weight') + if 'beta' in key: + new_key = key.replace('beta', 'bias') + if new_key: + old_keys.append(key) + new_keys.append(new_key) + for old_key, new_key in zip(old_keys, new_keys): + pt_state_dict[new_key] = pt_state_dict.pop(old_key) + symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights weight_value_tuples = []