Update the PyTorch→TensorFlow checkpoint conversion script: add a --use_cached_models flag to skip re-downloading checkpoints, and strip the base-model prefix when loading a PyTorch state dict into a TF2 model

This commit is contained in:
thomwolf 2019-09-18 12:11:32 +02:00
parent f6969cc12b
commit 6a083fd447
2 changed files with 24 additions and 4 deletions

View File

@ -102,7 +102,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file
tf_model.save_weights(tf_dump_path, save_format='h5')
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False):
def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False):
assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory"
if args_model_type is None:
@ -126,8 +126,8 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with
if 'finetuned' in shortcut_name:
print(" Skipping finetuned checkpoint ")
continue
config_file = cached_path(aws_config_map[shortcut_name], force_download=True)
model_file = cached_path(aws_model_maps[shortcut_name], force_download=True)
config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models)
model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models)
convert_pt_checkpoint_to_tf(model_type,
model_file,
@ -165,6 +165,9 @@ if __name__ == "__main__":
parser.add_argument("--compare_with_pt_model",
action='store_true',
help = "Compare Tensorflow and PyTorch model predictions.")
parser.add_argument("--use_cached_models",
action='store_true',
help = "Use cached models if possible instead of updating to latest checkpoint versions.")
args = parser.parse_args()
if args.pytorch_checkpoint_path is not None:
@ -176,4 +179,5 @@ if __name__ == "__main__":
else:
convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None,
args.tf_dump_path,
compare_with_pt_model=args.compare_with_pt_model)
compare_with_pt_model=args.compare_with_pt_model,
use_cached_models=args.use_cached_models)

View File

@ -78,6 +78,12 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
for old_key, new_key in zip(old_keys, new_keys):
pt_state_dict[new_key] = pt_state_dict.pop(old_key)
# Make sure we are able to load PyTorch base models as well as derived models (with heads)
# TF models always have a prefix, some of PyTorch models (base ones) don't
start_prefix_to_remove = ''
if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()):
start_prefix_to_remove = tf_model.base_model_prefix + '.'
symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights
weight_value_tuples = []
@ -100,13 +106,23 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None
if name[-1] == 'beta':
name[-1] = 'bias'
# Remove prefix if needed
name = '.'.join(name)
if start_prefix_to_remove:
name = name.replace(start_prefix_to_remove, '', 1)
# Find associated numpy array in pytorch model state dict
assert name in pt_state_dict, "{} not found in PyTorch model".format(name)
array = pt_state_dict[name].numpy()
if transpose:
array = numpy.transpose(array)
if len(symbolic_weight.shape) < len(array.shape):
array = numpy.squeeze(array)
elif len(symbolic_weight.shape) > len(array.shape):
array = numpy.expand_dims(array, axis=0)
try:
assert list(symbolic_weight.shape) == list(array.shape)
except AssertionError as e: