Fix confusing warnings during TF2 import from PyTorch (#6623)

1. The `missing_keys` and `unexpected_keys` lists were swapped with each other; this commit assigns each to the correct name.

2. A copy-and-paste error caused these warnings to say "from TF 2.0" when the weights are actually loaded "from PyTorch".
This commit is contained in:
Johann C. Rocholl 2020-09-10 02:31:59 -07:00 committed by GitHub
parent 4ee1053dcf
commit 49e9be0639
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -148,7 +148,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
tf_loaded_numel = 0 tf_loaded_numel = 0
weight_value_tuples = [] weight_value_tuples = []
all_pytorch_weights = set(list(pt_state_dict.keys())) all_pytorch_weights = set(list(pt_state_dict.keys()))
unexpected_keys = [] missing_keys = []
for symbolic_weight in symbolic_weights: for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name( name, transpose = convert_tf_weight_name_to_pt_weight_name(
@ -158,7 +158,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
# Find associated numpy array in pytorch model state dict # Find associated numpy array in pytorch model state dict
if name not in pt_state_dict: if name not in pt_state_dict:
if allow_missing_keys: if allow_missing_keys:
unexpected_keys.append(name) missing_keys.append(name)
continue continue
raise AttributeError("{} not found in PyTorch model".format(name)) raise AttributeError("{} not found in PyTorch model".format(name))
@ -192,28 +192,28 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel)) logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))
missing_keys = list(all_pytorch_weights) unexpected_keys = list(all_pytorch_weights)
if len(unexpected_keys) > 0: if len(unexpected_keys) > 0:
logger.warning( logger.warning(
f"Some weights of the PyTorch model were not used when " f"Some weights of the PyTorch model were not used when "
f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n" f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model trained on another task " f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n" f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model that you expect " f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect "
f"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)." f"to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)."
) )
else: else:
logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n") logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
if len(missing_keys) > 0: if len(missing_keys) > 0:
logger.warning( logger.warning(
f"Some weights or buffers of the PyTorch model {tf_model.__class__.__name__} were not initialized from the TF 2.0 model " f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the PyTorch model "
f"and are newly initialized: {missing_keys}\n" f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference." f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
) )
else: else:
logger.warning( logger.warning(
f"All the weights of {tf_model.__class__.__name__} were initialized from the TF 2.0 model.\n" f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, " f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"you can already use {tf_model.__class__.__name__} for predictions without further training." f"you can already use {tf_model.__class__.__name__} for predictions without further training."
) )