From 49e9be063938bb6be48e3a82881f68c20f1521db Mon Sep 17 00:00:00 2001
From: "Johann C. Rocholl"
Date: Thu, 10 Sep 2020 02:31:59 -0700
Subject: [PATCH] Fix confusing warnings during TF2 import from PyTorch (#6623)

1. Swapped missing_keys and unexpected_keys.
2. Copy&paste error caused these warnings to say "from TF 2.0" when it's actually "from PyTorch".
---
 src/transformers/modeling_tf_pytorch_utils.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index 5600f8d6630..0316f49ad73 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -148,7 +148,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
     tf_loaded_numel = 0
     weight_value_tuples = []
     all_pytorch_weights = set(list(pt_state_dict.keys()))
-    unexpected_keys = []
+    missing_keys = []
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
@@ -158,7 +158,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
         # Find associated numpy array in pytorch model state dict
         if name not in pt_state_dict:
             if allow_missing_keys:
-                unexpected_keys.append(name)
+                missing_keys.append(name)
                 continue
             raise AttributeError("{} not found in PyTorch model".format(name))

@@ -192,28 +192,28 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a

     logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))

-    missing_keys = list(all_pytorch_weights)
+    unexpected_keys = list(all_pytorch_weights)

     if len(unexpected_keys) > 0:
         logger.warning(
             f"Some weights of the PyTorch model were not used when "
             f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
-            f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model trained on another task "
-            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
-            f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model that you expect "
-            f"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
+            f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task "
+            f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n"
+            f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect "
+            f"to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)."
         )
     else:
         logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")

     if len(missing_keys) > 0:
         logger.warning(
-            f"Some weights or buffers of the PyTorch model {tf_model.__class__.__name__} were not initialized from the TF 2.0 model "
+            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the PyTorch model "
             f"and are newly initialized: {missing_keys}\n"
             f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
         )
     else:
         logger.warning(
-            f"All the weights of {tf_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
+            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
             f"If your task is similar to the task the model of the ckeckpoint was trained on, "
             f"you can already use {tf_model.__class__.__name__} for predictions without further training."
         )
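
For context (not part of the patch): a minimal sketch of the call that typically exercises load_pytorch_weights_in_tf2_model() and emits the warnings corrected above. The checkpoint path is a placeholder, and TFBertForSequenceClassification is used only as an example model class.

    # Sketch only: loading PyTorch weights into a TF 2.0 model with from_pt=True
    # routes through load_pytorch_weights_in_tf2_model(), which logs the
    # warnings fixed in this patch.
    # "path/to/pt_checkpoint" is a placeholder directory containing
    # pytorch_model.bin and config.json.
    from transformers import TFBertForSequenceClassification

    tf_model = TFBertForSequenceClassification.from_pretrained(
        "path/to/pt_checkpoint", from_pt=True
    )
    # unexpected_keys -> PyTorch weights not used by the TF 2.0 model
    # missing_keys    -> TF 2.0 weights with no counterpart in the PyTorch state dict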