Update PT/TF weight conversion after #24030 (#24547)

* fix

* fix

* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
Yih-Dar 2023-06-28 16:36:57 +02:00 committed by GitHub
parent c5e29d4381
commit 6c57ce1558
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -273,6 +273,18 @@ def load_pytorch_state_dict_in_tf2_model(
new_key = key.replace("running_var", "moving_variance")
if "running_mean" in key:
new_key = key.replace("running_mean", "moving_mean")
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
key_components = key.split(".")
name = None
if key_components[-3::2] == ["parametrizations", "original0"]:
name = key_components[-2] + "_g"
elif key_components[-3::2] == ["parametrizations", "original1"]:
name = key_components[-2] + "_v"
if name is not None:
key_components = key_components[:-3] + [name]
new_key = ".".join(key_components)
if new_key is None:
new_key = key
tf_keys_to_pt_keys[new_key] = key
@ -499,15 +511,27 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
continue
pt_weight_name_to_check = pt_weight_name
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
key_components = pt_weight_name.split(".")
name = None
if key_components[-3::2] == ["parametrizations", "original0"]:
name = key_components[-2] + "_g"
elif key_components[-3::2] == ["parametrizations", "original1"]:
name = key_components[-2] + "_v"
if name is not None:
key_components = key_components[:-3] + [name]
pt_weight_name_to_check = ".".join(key_components)
# Find associated numpy array in pytorch model state dict
if pt_weight_name not in tf_weights_map:
if pt_weight_name_to_check not in tf_weights_map:
if allow_missing_keys:
missing_keys_pt.append(pt_weight_name)
continue
raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")
array, transpose = tf_weights_map[pt_weight_name]
array, transpose = tf_weights_map[pt_weight_name_to_check]
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)