mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-03 03:31:05 +06:00
* fix * fix * fix --------- Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
This commit is contained in:
parent
c5e29d4381
commit
6c57ce1558
@ -273,6 +273,18 @@ def load_pytorch_state_dict_in_tf2_model(
|
||||
new_key = key.replace("running_var", "moving_variance")
|
||||
if "running_mean" in key:
|
||||
new_key = key.replace("running_mean", "moving_mean")
|
||||
|
||||
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
|
||||
key_components = key.split(".")
|
||||
name = None
|
||||
if key_components[-3::2] == ["parametrizations", "original0"]:
|
||||
name = key_components[-2] + "_g"
|
||||
elif key_components[-3::2] == ["parametrizations", "original1"]:
|
||||
name = key_components[-2] + "_v"
|
||||
if name is not None:
|
||||
key_components = key_components[:-3] + [name]
|
||||
new_key = ".".join(key_components)
|
||||
|
||||
if new_key is None:
|
||||
new_key = key
|
||||
tf_keys_to_pt_keys[new_key] = key
|
||||
@ -499,15 +511,27 @@ def load_tf2_state_dict_in_pytorch_model(pt_model, tf_state_dict, allow_missing_
|
||||
new_pt_params_dict[pt_weight_name] = loaded_pt_weights_data_ptr[pt_weight.data_ptr()]
|
||||
continue
|
||||
|
||||
pt_weight_name_to_check = pt_weight_name
|
||||
# New `weight_norm` from https://github.com/huggingface/transformers/pull/24030
|
||||
key_components = pt_weight_name.split(".")
|
||||
name = None
|
||||
if key_components[-3::2] == ["parametrizations", "original0"]:
|
||||
name = key_components[-2] + "_g"
|
||||
elif key_components[-3::2] == ["parametrizations", "original1"]:
|
||||
name = key_components[-2] + "_v"
|
||||
if name is not None:
|
||||
key_components = key_components[:-3] + [name]
|
||||
pt_weight_name_to_check = ".".join(key_components)
|
||||
|
||||
# Find associated numpy array in pytorch model state dict
|
||||
if pt_weight_name not in tf_weights_map:
|
||||
if pt_weight_name_to_check not in tf_weights_map:
|
||||
if allow_missing_keys:
|
||||
missing_keys_pt.append(pt_weight_name)
|
||||
continue
|
||||
|
||||
raise AttributeError(f"{pt_weight_name} not found in TF 2.0 model")
|
||||
|
||||
array, transpose = tf_weights_map[pt_weight_name]
|
||||
array, transpose = tf_weights_map[pt_weight_name_to_check]
|
||||
|
||||
array = apply_transpose(transpose, array, pt_weight.shape, pt_to_tf=False)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user