Cleaner warning when loading pretrained models (#4557)
* Cleaner warning when loading pretrained models

This makes the logging messages emitted by the various `from_pretrained` methods more explicit. It also raises these messages to `logging.warning`, because they are a common source of silent mistakes.

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* Update src/transformers/modeling_utils.py

Co-authored-by: Julien Chaumond <chaumond@gmail.com>

* style and quality

Co-authored-by: Julien Chaumond <chaumond@gmail.com>
This commit is contained in:
parent 4e741efa92
commit 75e1eed8d1
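
Before the diff, a quick illustration of the effect (a minimal sketch, assuming `transformers` at this commit and the public `bert-base-uncased` checkpoint; the exact key lists depend on the checkpoint):

    # Loading a task-specific head from a base checkpoint now reports the
    # mismatched keys via logging.warning instead of the quieter logging.info.
    from transformers import BertForSequenceClassification

    model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
    # Abridged expected log output:
    #   Some weights of the model checkpoint at bert-base-uncased were not used
    #   when initializing BertForSequenceClassification: [...]
    #   Some weights of BertForSequenceClassification were not initialized from
    #   the model checkpoint at bert-base-uncased and are newly initialized: [...]
    #   You should probably TRAIN this model on a down-stream task to be able
    #   to use it for predictions and inference.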
@@ -150,6 +150,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
     tf_loaded_numel = 0
     weight_value_tuples = []
     all_pytorch_weights = set(list(pt_state_dict.keys()))
+    missing_keys = []
     for symbolic_weight in symbolic_weights:
         sw_name = symbolic_weight.name
         name, transpose = convert_tf_weight_name_to_pt_weight_name(
@@ -159,6 +160,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
         # Find associated numpy array in pytorch model state dict
         if name not in pt_state_dict:
             if allow_missing_keys:
+                missing_keys.append(name)
                 continue

             raise AttributeError("{} not found in PyTorch model".format(name))
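
The loop above hinges on `convert_tf_weight_name_to_pt_weight_name`, which translates TF variable names into PyTorch state-dict keys. A hedged sketch of that step (the exact mapping rules live in `modeling_tf_pytorch_utils.py` and can differ across versions; the outputs shown are illustrative):

    from transformers.modeling_tf_pytorch_utils import convert_tf_weight_name_to_pt_weight_name

    # TF scoped names such as "bert/pooler/dense/kernel:0" become dotted
    # PyTorch names; `transpose` flags kernels whose arrays must be transposed.
    name, transpose = convert_tf_weight_name_to_pt_weight_name("bert/pooler/dense/kernel:0")
    # name      -> "bert.pooler.dense.weight" (illustrative)
    # transpose -> True                       (illustrative)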
@@ -192,7 +194,31 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, allow_missing_keys=False):
     logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))

-    logger.info("Weights or buffers not loaded from PyTorch model: {}".format(all_pytorch_weights))
+    unexpected_keys = list(all_pytorch_weights)
+
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            f"Some weights of the PyTorch model were not used when "
+            f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
+            f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task "
+            f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).\n"
+            f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect "
+            f"to be exactly identical (initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the PyTorch model "
+            f"and are newly initialized: {missing_keys}\n"
+            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
+            f"If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {tf_model.__class__.__name__} for predictions without further training."
+        )

     return tf_model
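
The key lists here fall out of plain set bookkeeping: every PyTorch key starts in `all_pytorch_weights`, keys matched during the loop are discarded, and the leftovers are the checkpoint weights that were never used. A self-contained sketch with invented key names:

    pt_state_dict = {"encoder.weight": None, "encoder.bias": None, "cls.weight": None}
    all_pytorch_weights = set(pt_state_dict.keys())

    for matched_name in ("encoder.weight", "encoder.bias"):  # keys the TF model resolved
        all_pytorch_weights.discard(matched_name)

    unexpected_keys = list(all_pytorch_weights)
    print(unexpected_keys)  # ['cls.weight'] -> reported as "not used" by the TF model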
@@ -317,13 +343,28 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=False):
     missing_keys, unexpected_keys = pt_model.load_state_dict(new_pt_params_dict, strict=False)
     missing_keys += missing_keys_pt

-    if len(missing_keys) > 0:
-        logger.info(
-            "Weights of {} not initialized from TF 2.0 model: {}".format(pt_model.__class__.__name__, missing_keys)
-        )
     if len(unexpected_keys) > 0:
-        logger.info(
-            "Weights from TF 2.0 model not used in {}: {}".format(pt_model.__class__.__name__, unexpected_keys)
-        )
+        logger.warning(
+            f"Some weights of the TF 2.0 model were not used when "
+            f"initializing the PyTorch model {pt_model.__class__.__name__}: {unexpected_keys}\n"
+            f"- This IS expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model trained on another task "
+            f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPreTraining model).\n"
+            f"- This IS NOT expected if you are initializing {pt_model.__class__.__name__} from a TF 2.0 model that you expect "
+            f"to be exactly identical (initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
+        )
+    else:
+        logger.warning(f"All TF 2.0 model weights were used when initializing {pt_model.__class__.__name__}.\n")
+    if len(missing_keys) > 0:
+        logger.warning(
+            f"Some weights of {pt_model.__class__.__name__} were not initialized from the TF 2.0 model "
+            f"and are newly initialized: {missing_keys}\n"
+            f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+        )
+    else:
+        logger.warning(
+            f"All the weights of {pt_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
+            f"If your task is similar to the task the model of the checkpoint was trained on, "
+            f"you can already use {pt_model.__class__.__name__} for predictions without further training."
+        )

     logger.info("Weights or buffers not loaded from TF 2.0 model: {}".format(all_tf_weights))
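
This path leans on the `strict=False` contract of PyTorch's `load_state_dict`: instead of raising on mismatched keys, it returns them, which is exactly what feeds the warnings above. A minimal standalone example:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    # "bias" is absent and "extra" is a stray key; strict=True would raise here.
    state = {"weight": torch.zeros(2, 4), "extra": torch.zeros(1)}

    result = model.load_state_dict(state, strict=False)
    print(result.missing_keys)     # ['bias']
    print(result.unexpected_keys)  # ['extra']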
@@ -504,13 +504,28 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin):
         unexpected_keys = list(hdf5_layer_names - model_layer_names)
         error_msgs = []

-        if len(missing_keys) > 0:
-            logger.info(
-                "Layers of {} not initialized from pretrained model: {}".format(model.__class__.__name__, missing_keys)
-            )
         if len(unexpected_keys) > 0:
-            logger.info(
-                "Layers from pretrained model not used in {}: {}".format(model.__class__.__name__, unexpected_keys)
-            )
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.warning(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        else:
+            logger.warning(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
         if len(error_msgs) > 0:
             raise RuntimeError(
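
In the TF 2.0 `from_pretrained` path the two key lists are simple set differences between the model's layer names and those recorded in the H5 checkpoint. A toy illustration (layer names invented for the example):

    model_layer_names = {"bert", "classifier"}
    hdf5_layer_names = {"bert", "mlm___cls"}  # e.g. what a pretraining checkpoint might store

    missing_keys = list(model_layer_names - hdf5_layer_names)
    unexpected_keys = list(hdf5_layer_names - model_layer_names)
    print(missing_keys)     # ['classifier'] -> "newly initialized" warning
    print(unexpected_keys)  # ['mlm___cls']  -> "not used" warning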
@@ -750,17 +750,28 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin):

             missing_keys.extend(head_model_state_dict_without_base_prefix - base_model_state_dict)

-        if len(missing_keys) > 0:
-            logger.info(
-                "Weights of {} not initialized from pretrained model: {}".format(
-                    model.__class__.__name__, missing_keys
-                )
-            )
         if len(unexpected_keys) > 0:
-            logger.info(
-                "Weights from pretrained model not used in {}: {}".format(
-                    model.__class__.__name__, unexpected_keys
-                )
-            )
+            logger.warning(
+                f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when "
+                f"initializing {model.__class__.__name__}: {unexpected_keys}\n"
+                f"- This IS expected if you are initializing {model.__class__.__name__} from the checkpoint of a model trained on another task "
+                f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n"
+                f"- This IS NOT expected if you are initializing {model.__class__.__name__} from the checkpoint of a model that you expect "
+                f"to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
+            )
+        else:
+            logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
+        if len(missing_keys) > 0:
+            logger.warning(
+                f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
+                f"and are newly initialized: {missing_keys}\n"
+                f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
+            )
+        else:
+            logger.info(
+                f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
+                f"If your task is similar to the task the model of the checkpoint was trained on, "
+                f"you can already use {model.__class__.__name__} for predictions without further training."
+            )
         if len(error_msgs) > 0:
             raise RuntimeError(
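
Since the messages are now emitted at WARNING level they show up by default; users who consider them noise can lower the logger's level. A sketch assuming the standard-library logging setup `transformers` used at the time (module-level loggers named after the module path):

    import logging

    # Hide the from_pretrained loading reports from the PyTorch base class.
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)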