From 6a083fd4478be2f5163c08c93edd6b65dcbd0a98 Mon Sep 17 00:00:00 2001 From: thomwolf Date: Wed, 18 Sep 2019 12:11:32 +0200 Subject: [PATCH] update pt-tf conversion script --- .../convert_pytorch_checkpoint_to_tf2.py | 12 ++++++++---- .../modeling_tf_pytorch_utils.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py index ca6089bd044..25ed70c2db0 100644 --- a/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py +++ b/pytorch_transformers/convert_pytorch_checkpoint_to_tf2.py @@ -102,7 +102,7 @@ def convert_pt_checkpoint_to_tf(model_type, pytorch_checkpoint_path, config_file tf_model.save_weights(tf_dump_path, save_format='h5') -def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False): +def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with_pt_model=False, use_cached_models=False): assert os.path.isdir(args.tf_dump_path), "--tf_dump_path should be a directory" if args_model_type is None: @@ -126,8 +126,8 @@ def convert_all_pt_checkpoints_to_tf(args_model_type, tf_dump_path, compare_with if 'finetuned' in shortcut_name: print(" Skipping finetuned checkpoint ") continue - config_file = cached_path(aws_config_map[shortcut_name], force_download=True) - model_file = cached_path(aws_model_maps[shortcut_name], force_download=True) + config_file = cached_path(aws_config_map[shortcut_name], force_download=not use_cached_models) + model_file = cached_path(aws_model_maps[shortcut_name], force_download=not use_cached_models) convert_pt_checkpoint_to_tf(model_type, model_file, @@ -165,6 +165,9 @@ if __name__ == "__main__": parser.add_argument("--compare_with_pt_model", action='store_true', help = "Compare Tensorflow and PyTorch model predictions.") + parser.add_argument("--use_cached_models", + action='store_true', + help = "Use cached models if possible instead of updating to latest checkpoint versions.") args = parser.parse_args() if args.pytorch_checkpoint_path is not None: @@ -176,4 +179,5 @@ if __name__ == "__main__": else: convert_all_pt_checkpoints_to_tf(args.model_type.lower() if args.model_type is not None else None, args.tf_dump_path, - compare_with_pt_model=args.compare_with_pt_model) + compare_with_pt_model=args.compare_with_pt_model, + use_cached_models=args.use_cached_models) diff --git a/pytorch_transformers/modeling_tf_pytorch_utils.py b/pytorch_transformers/modeling_tf_pytorch_utils.py index 2420b69fb93..9950a5a73fa 100644 --- a/pytorch_transformers/modeling_tf_pytorch_utils.py +++ b/pytorch_transformers/modeling_tf_pytorch_utils.py @@ -78,6 +78,12 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None for old_key, new_key in zip(old_keys, new_keys): pt_state_dict[new_key] = pt_state_dict.pop(old_key) + # Make sure we are able to load PyTorch base models as well as derived models (with heads) + # TF models always have a prefix, some of PyTorch models (base ones) don't + start_prefix_to_remove = '' + if not any(s.startswith(tf_model.base_model_prefix) for s in pt_state_dict.keys()): + start_prefix_to_remove = tf_model.base_model_prefix + '.' + symbolic_weights = tf_model.trainable_weights + tf_model.non_trainable_weights weight_value_tuples = [] @@ -100,13 +106,23 @@ def load_pytorch_state_dict_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None if name[-1] == 'beta': name[-1] = 'bias' + # Remove prefix if needed name = '.'.join(name) + if start_prefix_to_remove: + name = name.replace(start_prefix_to_remove, '', 1) + + # Find associated numpy array in pytorch model state dict assert name in pt_state_dict, "{} not found in PyTorch model".format(name) array = pt_state_dict[name].numpy() if transpose: array = numpy.transpose(array) + if len(symbolic_weight.shape) < len(array.shape): + array = numpy.squeeze(array) + elif len(symbolic_weight.shape) > len(array.shape): + array = numpy.expand_dims(array, axis=0) + try: assert list(symbolic_weight.shape) == list(array.shape) except AssertionError as e: