Cleaner warning when loading pretrained models (#4557)

* Cleaner warning when loading pretrained models

This make more explicit logging messages when using the various `from_pretrained` methods. It also make these messages as `logging.warning` because it's a common source of silent mistakes.

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* style and quality

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
Thomas Wolf 2020-06-22 21:58:47 +02:00 committed by GitHub
parent 4e741efa92
commit 75e1eed8d1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 89 additions and 22 deletions

View File

@ -150,6 +150,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
tf_loaded_numel = 0
weight_value_tuples = []
all_pytorch_weights = set(list(pt_state_dict.keys()))
unexpected_keys = []
for symbolic_weight in symbolic_weights:
sw_name = symbolic_weight.name
name, transpose = convert_tf_weight_name_to_pt_weight_name(
@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
# Find associated numpy array in pytorch model state dict
if name not in pt_state_dict:
if allow_missing_keys:
unexpected_keys.append(name)
continue
raise AttributeError("{} not found in PyTorch model".format(name))
@ -192,7 +194,31 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))
logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
missing_keys = list(all_pytorch_weights)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the PyTorch model were not used when "
f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
)
else:
logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights or buffers of the PyTorch model {tf_model.__class__.__name__} were not initialized from the TF 2.0 model "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {tf_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
)
return tf_model
@ -317,13 +343,28 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
missing_keys += missing_keys_pt
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from TF 2.0 model: {}".format(pt_model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info(
"Weights from TF 2.0 model not used in {}: {}".format(pt_model.__class__.__name__, unexpected_keys)
logger.warning(
f"Some weights of the TF 2.0 model were not used when "
f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
)
else:
logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"you can already use {pt_model.__class__.__name__} for predictions without further training."
)
logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights))

View File

@ -504,13 +504,28 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
unexpected_keys = list(hdf5_layer_names - model_layer_names)
error_msgs = []
if len(missing_keys) > 0:
logger.info(
"Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
)
if len(unexpected_keys) > 0:
logger.info(
"Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.warning(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(error_msgs) > 0:
raise RuntimeError(

View File

@ -750,17 +750,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):
missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)
if len(missing_keys) > 0:
logger.info(
"Weights of {} not initialized from pretrained model: {}".format(
model.__class__.__name__, missing_keys
)
)
if len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).\n"
f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
)
else:
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
if len(missing_keys) > 0:
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
logger.info(
"Weights from pretrained model not used in {}: {}".format(
model.__class__.__name__, unexpected_keys
)
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(error_msgs) > 0:
raise RuntimeError(