Add option to load a pretrained model with mismatched shapes (#12664)

* Add option to load a pretrained model with mismatched shapes

* Fail at load time on mismatched shapes in Flax

* Fix tests

* Update src/transformers/modeling_flax_utils.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* Address review comments

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Sylvain Gugger
Date: 2021-07-13 10:15:15 -04:00 (committed by GitHub)
Commit: 90178b0cef
Parent: 7f6d375029

9 changed files with 228 additions and 67 deletions
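For reference, a minimal usage sketch of the option this commit adds; the checkpoint name and label count below are illustrative, not part of the commit:

```python
from transformers import AutoModelForSequenceClassification

# Reusing a checkpoint fine-tuned with a 3-label head for a new 42-label task.
# Without the new flag, from_pretrained fails because the classifier shapes differ.
model = AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-MNLI",  # illustrative 3-label checkpoint
    num_labels=42,
    ignore_mismatched_sizes=True,  # option added by this commit
)
```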

src/transformers/modeling_flax_utils.py

@ -199,6 +199,10 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
from_pt (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@ -242,6 +246,7 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
from_pt = kwargs.pop("from_pt", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
@ -367,6 +372,22 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
missing_keys = model.required_params - set(state.keys())
unexpected_keys = set(state.keys()) - model.required_params
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
for key in state.keys():
if key in random_state and state[key].shape != random_state[key].shape:
if ignore_mismatched_sizes:
mismatched_keys.append((key, state[key].shape, random_state[key].shape))
state[key] = random_state[key]
else:
raise ValueError(
f"Trying to load the pretrained weight for {key} failed: checkpoint has shape "
f"{state[key].shape} which is incompatible with the model shape {random_state[key].shape}. "
"Using `ignore_mismatched_sizes=True` if you really want to load this checkpoint inside this "
"model."
)
# add missing keys as random parameters
for missing_key in missing_keys:
state[missing_key] = random_state[missing_key]
@ -393,12 +414,24 @@ class FlaxPreTrainedModel(PushToHubMixin, FlaxGenerationMixin):
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
# set correct parameters
model.params = unflatten_dict(state)
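A sketch mirroring the new Flax test added further down; the tiny config and label counts are illustrative:

```python
import tempfile

from transformers import BertConfig, FlaxBertForSequenceClassification

# Tiny config so the sketch runs quickly; the saved checkpoint has a 3-label head.
config = BertConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2, intermediate_size=64, num_labels=3
)
model = FlaxBertForSequenceClassification(config)

with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir)

    # Default behavior after this commit: loading into a 42-label head raises a ValueError.
    try:
        FlaxBertForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
    except ValueError as err:
        print("raised as expected:", err)

    # With the flag, the mismatched classifier weights keep their random initialization
    # and a warning lists the affected keys.
    new_model = FlaxBertForSequenceClassification.from_pretrained(
        tmp_dir, num_labels=42, ignore_mismatched_sizes=True
    )
```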

src/transformers/modeling_tf_utils.py

@ -450,7 +450,7 @@ def input_processing(func, config, input_ids, **kwargs):
return output
def load_tf_weights(model, resolved_archive_file, _prefix=None):
def load_tf_weights(model, resolved_archive_file, ignore_mismatched_sizes=False, _prefix=None):
"""
Detect missing and unexpected layers and load the TF weights accordingly to their names and shapes.
@ -459,12 +459,16 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None):
The model to load the weights into.
resolved_archive_file (:obj:`str`):
The location of the H5 file.
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to ignore weights with shapes that don't match between the checkpoint and the model.
Returns:
Two lists, one for the missing layers, and another one for the unexpected layers.
Three lists, one for the missing layers, another one for the unexpected layers, and a last one for the
mismatched layers.
"""
missing_layers = []
unexpected_layers = []
mismatched_layers = []
# Read the H5 file
with h5py.File(resolved_archive_file, "r") as f:
@ -533,9 +537,14 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None):
# If the two shapes are not compatible, record the mismatch or raise an error
try:
array = np.reshape(saved_weight_value, K.int_shape(symbolic_weight))
except AssertionError as e:
e.args += (K.int_shape(symbolic_weight), saved_weight_value.shape)
raise e
except ValueError as e:
if ignore_mismatched_sizes:
mismatched_layers.append(
(symbolic_weight_name, saved_weight_value.shape, K.int_shape(symbolic_weight))
)
continue
else:
raise e
else:
array = saved_weight_value
@ -549,7 +558,7 @@ def load_tf_weights(model, resolved_archive_file, _prefix=None):
missing_layers.extend(list(symbolic_weights_names - saved_weight_names_set))
unexpected_layers.extend(list(saved_weight_names_set - symbolic_weights_names))
return missing_layers, unexpected_layers
return missing_layers, unexpected_layers, mismatched_layers
def init_copy_embeddings(old_embeddings, new_num_tokens):
@ -1123,6 +1132,10 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
from_pt: (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a PyTorch state_dict save file (see docstring of
``pretrained_model_name_or_path`` argument).
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
cache_dir (:obj:`str`, `optional`):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
@ -1186,6 +1199,7 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
config = kwargs.pop("config", None)
cache_dir = kwargs.pop("cache_dir", None)
from_pt = kwargs.pop("from_pt", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
@ -1307,7 +1321,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
# 'by_name' allows us to do transfer learning by skipping/adding layers
# see https://github.com/tensorflow/tensorflow/blob/00fad90125b18b80fe054de1055770cfb8fe4ba3/tensorflow/python/keras/engine/network.py#L1339-L1357
try:
missing_keys, unexpected_keys = load_tf_weights(model, resolved_archive_file, load_weight_prefix)
missing_keys, unexpected_keys, mismatched_keys = load_tf_weights(
model,
resolved_archive_file,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_prefix=load_weight_prefix,
)
except OSError:
raise OSError(
"Unable to load weights from h5 file. "
@ -1342,15 +1361,31 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin, Pu
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
elif len(mismatched_keys) == 0:
logger.warning(
f"All the layers of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
if output_loading_info:
loading_info = {"missing_keys": missing_keys, "unexpected_keys": unexpected_keys}
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
}
return model, loading_info
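A hedged sketch of how the extended TF `loading_info` could be inspected; the checkpoint path is illustrative, not a real model:

```python
from transformers import TFAutoModelForSequenceClassification

# `loading_info` now also carries "mismatched_keys" alongside missing/unexpected keys.
model, loading_info = TFAutoModelForSequenceClassification.from_pretrained(
    "path/to/3-label-tf-checkpoint",  # illustrative path
    num_labels=42,
    ignore_mismatched_sizes=True,
    output_loading_info=True,
)
print(loading_info["mismatched_keys"])  # [(layer_name, checkpoint_shape, model_shape), ...]
```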

src/transformers/modeling_utils.py

@ -1037,6 +1037,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
from_flax (:obj:`bool`, `optional`, defaults to :obj:`False`):
Load the model weights from a Flax checkpoint save file (see docstring of
``pretrained_model_name_or_path`` argument).
ignore_mismatched_sizes (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
checkpoint with 3 labels).
force_download (:obj:`bool`, `optional`, defaults to :obj:`False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
@ -1120,6 +1124,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
cache_dir = kwargs.pop("cache_dir", None)
from_tf = kwargs.pop("from_tf", False)
from_flax = kwargs.pop("from_flax", False)
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
@ -1315,8 +1320,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
)
raise
elif from_pt:
model, missing_keys, unexpected_keys, error_msgs = cls._load_state_dict_into_model(
model, state_dict, pretrained_model_name_or_path, _fast_init=_fast_init
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_state_dict_into_model(
model,
state_dict,
pretrained_model_name_or_path,
ignore_mismatched_sizes=ignore_mismatched_sizes,
_fast_init=_fast_init,
)
# make sure token embedding weights are still tied if needed
@ -1329,6 +1338,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
loading_info = {
"missing_keys": missing_keys,
"unexpected_keys": unexpected_keys,
"mismatched_keys": mismatched_keys,
"error_msgs": error_msgs,
}
return model, loading_info
@ -1336,7 +1346,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
return model
@classmethod
def _load_state_dict_into_model(cls, model, state_dict, pretrained_model_name_or_path, _fast_init=True):
def _load_state_dict_into_model(
cls, model, state_dict, pretrained_model_name_or_path, ignore_mismatched_sizes=False, _fast_init=True
):
# Convert old format to new format if needed from a PyTorch state_dict
old_keys = []
@ -1354,7 +1366,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
state_dict[new_key] = state_dict.pop(old_key)
# Retrieve missing & unexpected_keys
expected_keys = list(model.state_dict().keys())
model_state_dict = model.state_dict()
expected_keys = list(model_state_dict.keys())
loaded_keys = list(state_dict.keys())
prefix = model.base_model_prefix
@ -1374,6 +1387,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
missing_keys = list(set(expected_keys) - set(loaded_keys))
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
# Mismatched keys contains tuples key/shape1/shape2 of weights in the checkpoint that have a shape not
# matching the weights in the model.
mismatched_keys = []
if ignore_mismatched_sizes:
for checkpoint_key in loaded_keys:
model_key = checkpoint_key
if remove_prefix and checkpoint_key.startswith(prefix):
model_key = ".".join(checkpoint_key.split(".")[1:])
elif add_prefix:
model_key = f"{prefix}.{checkpoint_key}"
if (
model_key in model_state_dict
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
):
mismatched_keys.append(
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
)
del state_dict[checkpoint_key]
# Some models may have keys that are not in the state by design, removing them before needlessly warning
# the user.
if cls._keys_to_ignore_on_load_missing is not None:
@ -1452,14 +1485,26 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
f"and are newly initialized: {missing_keys}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
else:
elif len(mismatched_keys) == 0:
logger.info(
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at {pretrained_model_name_or_path}.\n"
f"If your task is similar to the task the model of the checkpoint was trained on, "
f"you can already use {model.__class__.__name__} for predictions without further training."
)
if len(mismatched_keys) > 0:
mismatched_warning = "\n".join(
[
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
for key, shape1, shape2 in mismatched_keys
]
)
logger.warning(
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at {pretrained_model_name_or_path} "
f"and are newly initialized because the shapes did not match:\n{mismatched_warning}\n"
f"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference."
)
return model, missing_keys, unexpected_keys, error_msgs
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
module_keys = set([".".join(key.split(".")[:-1]) for key in names])
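A simplified, standalone sketch of the detection logic added to `_load_state_dict_into_model`; the real method also remaps keys when the checkpoint and the model disagree on the base-model prefix:

```python
def find_mismatched_keys(state_dict, model_state_dict, ignore_mismatched_sizes):
    """Collect (key, checkpoint_shape, model_shape) for tensors whose shapes differ,
    dropping them from the checkpoint so they keep the model's random initialization."""
    mismatched_keys = []
    if not ignore_mismatched_sizes:
        return mismatched_keys
    for key in list(state_dict.keys()):
        if key in model_state_dict and state_dict[key].shape != model_state_dict[key].shape:
            mismatched_keys.append((key, state_dict[key].shape, model_state_dict[key].shape))
            del state_dict[key]
    return mismatched_keys
```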

src/transformers/models/auto/configuration_auto.py

@ -186,7 +186,7 @@ CONFIG_MAPPING = OrderedDict(
("pegasus", PegasusConfig),
("marian", MarianConfig),
("mbart", MBartConfig),
("megatron_bert", MegatronBertConfig),
("megatron-bert", MegatronBertConfig),
("mpnet", MPNetConfig),
("bart", BartConfig),
("blenderbot", BlenderbotConfig),
@ -252,7 +252,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("blenderbot", "Blenderbot"),
("marian", "Marian"),
("mbart", "mBART"),
("megatron_bert", "MegatronBert"),
("megatron-bert", "MegatronBert"),
("bart", "BART"),
("reformer", "Reformer"),
("longformer", "Longformer"),

src/transformers/models/deberta/modeling_deberta.py

@ -760,10 +760,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = ["position_ids"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
def __init__(self, config):
super().__init__(config)
self._register_load_state_dict_pre_hook(self._pre_load_hook)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
@ -777,25 +773,6 @@ class DebertaPreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""
Removes the classifier if it doesn't have the correct number of labels.
"""
self_state = self.state_dict()
if (
("classifier.weight" in self_state)
and ("classifier.weight" in state_dict)
and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
):
logger.warning(
f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
f"weights. You should train your model on new data."
)
del state_dict["classifier.weight"]
if "classifier.bias" in state_dict:
del state_dict["classifier.bias"]
DEBERTA_START_DOCSTRING = r"""
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention

src/transformers/models/deberta_v2/modeling_deberta_v2.py

@ -881,10 +881,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
_keys_to_ignore_on_load_missing = ["position_ids"]
_keys_to_ignore_on_load_unexpected = ["position_embeddings"]
def __init__(self, config):
super().__init__(config)
self._register_load_state_dict_pre_hook(self._pre_load_hook)
def _init_weights(self, module):
"""Initialize the weights."""
if isinstance(module, nn.Linear):
@ -898,25 +894,6 @@ class DebertaV2PreTrainedModel(PreTrainedModel):
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _pre_load_hook(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
"""
Removes the classifier if it doesn't have the correct number of labels.
"""
self_state = self.state_dict()
if (
("classifier.weight" in self_state)
and ("classifier.weight" in state_dict)
and self_state["classifier.weight"].size() != state_dict["classifier.weight"].size()
):
logger.warning(
f"The checkpoint classifier head has a shape {state_dict['classifier.weight'].size()} and this model "
f"classifier head has a shape {self_state['classifier.weight'].size()}. Ignoring the checkpoint "
f"weights. You should train your model on new data."
)
del state_dict["classifier.weight"]
if "classifier.bias" in state_dict:
del state_dict["classifier.bias"]
DEBERTA_START_DOCSTRING = r"""
The DeBERTa model was proposed in `DeBERTa: Decoding-enhanced BERT with Disentangled Attention
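With the model-specific `_pre_load_hook` removed here and in the v1 file above, a DeBERTa classifier head of a different size is now handled by the generic flag instead of being dropped silently; a hedged sketch with an illustrative checkpoint path:

```python
from transformers import DebertaV2ForSequenceClassification

# The old hook silently discarded a classifier head whose size did not match;
# now the mismatch must be acknowledged explicitly with ignore_mismatched_sizes=True.
model = DebertaV2ForSequenceClassification.from_pretrained(
    "path/to/deberta-v2-3-label-checkpoint",  # illustrative
    num_labels=42,
    ignore_mismatched_sizes=True,
)
```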

tests/test_modeling_common.py

@ -25,7 +25,7 @@ from typing import Dict, List, Tuple
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import AutoModel, is_torch_available, logging
from transformers import AutoModel, AutoModelForSequenceClassification, is_torch_available, logging
from transformers.file_utils import WEIGHTS_NAME, is_torch_fx_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
@ -1532,6 +1532,35 @@ class ModelTesterMixin:
loss.backward()
def test_load_with_mismatched_shapes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class not in get_values(MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
continue
with self.subTest(msg=f"Testing {model_class}"):
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
model.save_pretrained(tmp_dir)
# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(RuntimeError):
new_model = AutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
logger = logging.get_logger("transformers.modeling_utils")
with CaptureLogger(logger) as cl:
new_model = AutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)
new_model.to(torch_device)
inputs = self._prepare_for_class(inputs_dict, model_class)
logits = new_model(**inputs).logits
self.assertEqual(logits.shape[1], 42)
global_rng = random.Random()

tests/test_modeling_flax_common.py

@ -24,17 +24,19 @@ import numpy as np
import transformers
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig, FlaxBertModel, is_flax_available, is_torch_available
from transformers import BertConfig, is_flax_available, is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
is_pt_flax_cross_test,
is_staging_test,
require_flax,
slow,
)
from transformers.utils import logging
if is_flax_available():
@ -45,7 +47,13 @@ if is_flax_available():
import jaxlib.xla_extension as jax_xla
from flax.core.frozen_dict import unfreeze
from flax.traverse_util import flatten_dict
from transformers import FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING, FLAX_MODEL_MAPPING
from transformers import (
FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING,
FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
FLAX_MODEL_MAPPING,
FlaxAutoModelForSequenceClassification,
FlaxBertModel,
)
from transformers.modeling_flax_pytorch_utils import (
convert_pytorch_state_dict_to_flax,
load_flax_weights_in_pytorch_model,
@ -516,6 +524,32 @@ class FlaxModelTesterMixin:
[self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
)
def test_load_with_mismatched_shapes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class not in get_values(FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
continue
with self.subTest(msg=f"Testing {model_class}"):
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
model.save_pretrained(tmp_dir)
# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(ValueError):
new_model = FlaxAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
logger = logging.get_logger("transformers.modeling_flax_utils")
with CaptureLogger(logger) as cl:
new_model = FlaxAutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)
logits = new_model(**inputs_dict)["logits"]
self.assertEqual(logits.shape[1], 42)
@require_flax
@is_staging_test

tests/test_modeling_tf_common.py

@ -32,6 +32,7 @@ from transformers.testing_utils import (
ENDPOINT_STAGING,
PASS,
USER,
CaptureLogger,
_tf_gpu_memory_limit,
is_pt_tf_cross_test,
is_staging_test,
@ -40,6 +41,7 @@ from transformers.testing_utils import (
slow,
tooslow,
)
from transformers.utils import logging
if is_tf_available():
@ -57,6 +59,7 @@ if is_tf_available():
TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING,
TF_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING,
BertConfig,
TFAutoModelForSequenceClassification,
TFBertModel,
TFSharedEmbeddings,
tf_top_k_top_p_filtering,
@ -1308,6 +1311,34 @@ class TFModelTesterMixin:
attn_weights = out[attn_name] if attn_name == attention_names[0] else out[attn_name][-1]
self.assertEqual(sum([tf.reduce_sum(w).numpy() for w in attn_weights]), 0.0)
def test_load_with_mismatched_shapes(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
for model_class in self.all_model_classes:
if model_class not in get_values(TF_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING):
continue
with self.subTest(msg=f"Testing {model_class}"):
with tempfile.TemporaryDirectory() as tmp_dir:
model = model_class(config)
inputs = self._prepare_for_class(inputs_dict, model_class)
_ = model(**inputs)
model.save_pretrained(tmp_dir)
# Fails when we don't set ignore_mismatched_sizes=True
with self.assertRaises(ValueError):
new_model = TFAutoModelForSequenceClassification.from_pretrained(tmp_dir, num_labels=42)
logger = logging.get_logger("transformers.modeling_tf_utils")
with CaptureLogger(logger) as cl:
new_model = TFAutoModelForSequenceClassification.from_pretrained(
tmp_dir, num_labels=42, ignore_mismatched_sizes=True
)
self.assertIn("the shapes did not match", cl.out)
logits = new_model(**inputs).logits
self.assertEqual(logits.shape[1], 42)
def _generate_random_bad_tokens(self, num_bad_tokens, model):
# special tokens cannot be bad tokens
special_tokens = []