Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
update pt-tf conversion script
This commit is contained in: parent f6969cc12b, commit 6a083fd447
@@ -102,7 +102,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
     tf_model.save_weights(tf_dump_path, save_format='h5')
 
 
-def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False):
+def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False):
     assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
 
     if args_model_type is None:
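Because the new use_cached_models parameter defaults to False, call sites that do not pass it keep the old always-redownload behavior. A minimal sketch of that default-keyword pattern (a toy stand-in, not the real function):

# Toy stand-in illustrating the default-keyword pattern; not the real converter.
def convert_all(model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False):
    return compare_with_pt_model, use_cached_models

print(convert_all('bert', '/tmp/tf'))                          # (False, False): old call sites unchanged
print(convert_all('bert', '/tmp/tf', use_cached_models=True))  # (False, True)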
@@ -126,8 +126,8 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
         if 'finetuned' in shortcut_name:
             print("    Skipping finetuned checkpoint ")
             continue
-        config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
-        model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
+        config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models)
+        model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models)
 
         convert_pt_checkpoint_to_tf(model_type,
                                     model_file,
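The two cached_path calls rely on the usual download-cache contract: force_download=False reuses a previously cached file, while force_download=True re-fetches it. A minimal sketch of that contract (fake_cached_path is a hypothetical stand-in, not the library's cached_path):

import os

def fake_cached_path(url, cache_dir='/tmp/cache', force_download=False):
    # Hypothetical stand-in for the library's cached_path; download elided.
    local_path = os.path.join(cache_dir, url.rsplit('/', 1)[-1])
    if os.path.exists(local_path) and not force_download:
        return local_path  # reuse the cached copy
    os.makedirs(cache_dir, exist_ok=True)
    # ... download url to local_path here ...
    return local_path

use_cached_models = True
# use_cached_models=True  -> force_download=False: cache hit is reused
# use_cached_models=False -> force_download=True : checkpoint is re-fetched
config_file = fake_cached_path('https://example.com/bert-config.json',
                               force_download=not use_cached_models)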
@@ -165,6 +165,9 @@ if __name__ == "__main__":
     parser.add_argument("--compare_with_pt_model",
                         action='store_true',
                         help = "Compare Tensorflow and PyTorch model predictions.")
+    parser.add_argument("--use_cached_models",
+                        action='store_true',
+                        help = "Use cached models if possible instead of updating to latest checkpoint versions.")
     args = parser.parse_args()
 
     if args.pytorch_checkpoint_path is not None:
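Since action='store_true' gives the new flag an implicit default of False, existing command lines behave exactly as before. A small self-contained check (a toy parser mirroring only the added argument):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--use_cached_models", action='store_true',
                    help="Use cached models if possible instead of updating to latest checkpoint versions.")

print(parser.parse_args([]).use_cached_models)                       # False
print(parser.parse_args(["--use_cached_models"]).use_cached_models)  # True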
@@ -176,4 +179,5 @@ if __name__ == "__main__":
     else:
         convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
                                          args.tf_dump_path,
-                                         compare_with_pt_model=args.compare_with_pt_model)
+                                         compare_with_pt_model=args.compare_with_pt_model,
+                                         use_cached_models=args.use_cached_models)
@@ -78,6 +78,12 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
     for old_key, new_key in zip(old_keys, new_keys):
         pt_state_dict[new_key] = pt_state_dict.pop(old_key)
 
+    # Make sure we are able to load PyTorch base models as well as derived models (with heads)
+    # TF models always have a prefix, some of PyTorch models (base ones) don't
+    start_prefix_to_remove = ''
+    if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
+        start_prefix_to_remove = tf_model.base_model_prefix + '.'
+
     symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
 
     weight_value_tuples = []
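The added block exists because a PyTorch base-model checkpoint stores keys like embeddings.word_embeddings.weight, while a derived model (with a head) stores bert.embeddings.word_embeddings.weight; the TF side always carries the prefix. A minimal sketch of the decision with toy state dicts ('bert' stands in for tf_model.base_model_prefix):

base_model_prefix = 'bert'  # stands in for tf_model.base_model_prefix

# Derived model (with head): keys already carry the base-model prefix.
derived = {'bert.embeddings.word_embeddings.weight': 0, 'classifier.weight': 0}
# Base model: keys start directly inside the encoder, no prefix.
base = {'embeddings.word_embeddings.weight': 0}

for pt_state_dict in (derived, base):
    start_prefix_to_remove = ''
    if not any(s.startswith(base_model_prefix) for s in pt_state_dict.keys()):
        start_prefix_to_remove = base_model_prefix + '.'
    print(repr(start_prefix_to_remove))  # '' for derived, 'bert.' for base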
@@ -100,13 +106,23 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
         if name[-1] == 'beta':
             name[-1] = 'bias'
 
+        # Remove prefix if needed
         name = '.'.join(name)
+        if start_prefix_to_remove:
+            name = name.replace(start_prefix_to_remove, '', 1)
 
         # Find associated numpy array in pytorch model state dict
         assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
         array = pt_state_dict[name].numpy()
 
         if transpose:
             array = numpy.transpose(array)
 
+        if len(symbolic_weight.shape) < len(array.shape):
+            array = numpy.squeeze(array)
+        elif len(symbolic_weight.shape) > len(array.shape):
+            array = numpy.expand_dims(array, axis=0)
+
         try:
             assert list(symbolic_weight.shape) == list(array.shape)
         except AssertionError as e:
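Taken together, the loop body now translates a TF variable name into a (possibly prefix-stripped) PyTorch key and reconciles ranks before the shape assertion. A hedged sketch on toy data (the variable name, split logic, and shapes are illustrative, not taken from a real checkpoint):

import numpy

# Illustrative TF variable name -> candidate PyTorch key.
tf_name = 'bert/embeddings/LayerNorm/beta:0'
name = tf_name.split(':')[0].split('/')
if name[-1] == 'beta':
    name[-1] = 'bias'
name = '.'.join(name)

start_prefix_to_remove = 'bert.'  # as computed above for a base model
if start_prefix_to_remove:
    name = name.replace(start_prefix_to_remove, '', 1)
print(name)  # embeddings.LayerNorm.bias

# Rank reconciliation: squeeze or expand the PyTorch array until it matches
# the rank of the TF weight, then assert the shapes agree.
symbolic_shape = (768,)        # shape of the TF weight (illustrative)
array = numpy.ones((1, 768))   # PyTorch array with an extra leading dim
if len(symbolic_shape) < len(array.shape):
    array = numpy.squeeze(array)
elif len(symbolic_shape) > len(array.shape):
    array = numpy.expand_dims(array, axis=0)
assert list(symbolic_shape) == list(array.shape)
print(array.shape)  # (768,)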