mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
Fix confusing warnings during TF2 import from PyTorch (#6623)
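This commit fixes two problems: (1) the missing_keys and unexpected_keys lists were swapped, so each warning reported the other's keys; (2) a copy-and-paste error caused these warnings to say "from TF 2.0" when the weights are actually loaded "from PyTorch".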
This commit is contained in:
parent
4ee1053dcf
commit
49e9be0639
@@ -148,7 +148,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
tf_loaded_numel = 0
|
tf_loaded_numel = 0
|
||||||
weight_value_tuples = []
|
weight_value_tuples = []
|
||||||
all_pytorch_weights = set(list(pt_state_dict.keys()))
|
all_pytorch_weights = set(list(pt_state_dict.keys()))
|
||||||
unexpected_keys = []
|
missing_keys = []
|
||||||
for symbolic_weight in symbolic_weights:
|
for symbolic_weight in symbolic_weights:
|
||||||
sw_name = symbolic_weight.name
|
sw_name = symbolic_weight.name
|
||||||
name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
name, transpose = convert_tf_weight_name_to_pt_weight_name(
|
||||||
@@ -158,7 +158,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
# Find associated numpy array in pytorch model state dict
|
# Find associated numpy array in pytorch model state dict
|
||||||
if name not in pt_state_dict:
|
if name not in pt_state_dict:
|
||||||
if allow_missing_keys:
|
if allow_missing_keys:
|
||||||
unexpected_keys.append(name)
|
missing_keys.append(name)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
raise AttributeError("{} not found in PyTorch model".format(name))
|
raise AttributeError("{} not found in PyTorch model".format(name))
|
||||||
@@ -192,28 +192,28 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
|
|||||||
|
|
||||||
logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))
|
logger.info("Loaded {:,} parameters in the TF 2.0 model.".format(tf_loaded_numel))
|
||||||
|
|
||||||
missing_keys = list(all_pytorch_weights)
|
unexpected_keys = list(all_pytorch_weights)
|
||||||
|
|
||||||
if len(unexpected_keys) > 0:
|
if len(unexpected_keys) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Some weights of the PyTorch model were not used when "
|
f"Some weights of the PyTorch model were not used when "
|
||||||
f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
|
f"initializing the TF 2.0 model {tf_model.__class__.__name__}: {unexpected_keys}\n"
|
||||||
f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model trained on another task "
|
f"- This IS expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model trained on another task "
|
||||||
f"or with another architecture (e.g. initializing a BertForSequenceClassification model from a TFBertForPretraining model).\n"
|
f"or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPretraining model).\n"
|
||||||
f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a TF 2.0 model that you expect "
|
f"- This IS NOT expected if you are initializing {tf_model.__class__.__name__} from a PyTorch model that you expect "
|
||||||
f"to be exactly identical (e.g. initializing a BertForSequenceClassification model from a TFBertForSequenceClassification model)."
|
f"to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
|
logger.warning(f"All PyTorch model weights were used when initializing {tf_model.__class__.__name__}.\n")
|
||||||
if len(missing_keys) > 0:
|
if len(missing_keys) > 0:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Some weights or buffers of the PyTorch model {tf_model.__class__.__name__} were not initialized from the TF 2.0 model "
|
f"Some weights or buffers of the TF 2.0 model {tf_model.__class__.__name__} were not initialized from the PyTorch model "
|
||||||
f"and are newly initialized: {missing_keys}\n"
|
f"and are newly initialized: {missing_keys}\n"
|
||||||
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"All the weights of {tf_model.__class__.__name__} were initialized from the TF 2.0 model.\n"
|
f"All the weights of {tf_model.__class__.__name__} were initialized from the PyTorch model.\n"
|
||||||
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
f"If your task is similar to the task the model of the ckeckpoint was trained on, "
|
||||||
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
f"you can already use {tf_model.__class__.__name__} for predictions without further training."
|
||||||
)
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user