Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected

* move public doc of private attributes to private comment

Parent: 49759c0cda
Commit: e84786aaa6
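Before the diff, a minimal sketch of what the rename means for a downstream model class (the subclass and the key patterns below are illustrative only, not part of this commit): the former public `authorized_*` / `keys_to_never_save` attributes become the private `_keys_to_ignore_on_*` attributes, with the same semantics.

import torch.nn as nn

from transformers import BertModel, BertPreTrainedModel


class MyBertClassifier(BertPreTrainedModel):
    # regex patterns for checkpoint keys that may be missing at load time without a warning
    # (was: authorized_missing_keys)
    _keys_to_ignore_on_load_missing = [r"position_ids", r"classifier\.bias"]
    # regex patterns for checkpoint keys that may go unused at load time without a warning
    # (was: authorized_unexpected_keys)
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    # exact state_dict keys that save_pretrained() skips (was: keys_to_never_save)
    _keys_to_ignore_on_save = None

    def __init__(self, config):
        super().__init__(config)
        self.bert = BertModel(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.init_weights()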
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
-            elif tf_model.authorized_missing_keys is not None:
+            elif tf_model._keys_to_ignore_on_load_missing is not None:
                 # authorized missing keys don't have to be loaded
-                if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                     continue
 
             raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
 
     unexpected_keys = list(all_pytorch_weights)
 
-    if tf_model.authorized_missing_keys is not None:
-        for pat in tf_model.authorized_missing_keys:
+    if tf_model._keys_to_ignore_on_load_missing is not None:
+        for pat in tf_model._keys_to_ignore_on_load_missing:
             missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
-    if tf_model.authorized_unexpected_keys is not None:
-        for pat in tf_model.authorized_unexpected_keys:
+    if tf_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
     if len(unexpected_keys) > 0:
@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
-          from the model when loading the model weights (and avoid unnecessary warnings).
-        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
-          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -742,12 +742,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
 
             model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
         if len(unexpected_keys) > 0:
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
 
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
-          when loading the model (and avoid unnecessary warnings).
-        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving the
-          model (useful for keys that aren't trained, but which are deterministic)
-
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
-    keys_to_never_save = None
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of of tensor names to ignore when saving the model (useful for keys that aren't
+    # trained, but which are deterministic)
+    _keys_to_ignore_on_save = None
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         state_dict = model_to_save.state_dict()
 
         # Handle the case where some state_dict keys shouldn't be saved
-        if self.keys_to_never_save is not None:
-            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+        if self._keys_to_ignore_on_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
 
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
 
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
         if len(unexpected_keys) > 0:
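The filtering these attributes drive is a plain regex scan over the loader's warning lists; a standalone sketch of that logic (the keys and patterns below are made up for illustration and mirror the diff above rather than calling library code):

import re

# toy warning lists, as from_pretrained would compute them
missing_keys = ["bert.embeddings.position_ids", "classifier.weight"]
unexpected_keys = ["bert.pooler.dense.weight"]

# patterns a model class would declare
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]

# drop every key matching an ignore pattern before warning the user
for pat in _keys_to_ignore_on_load_missing:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
for pat in _keys_to_ignore_on_load_unexpected:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

print(missing_keys)     # ['classifier.weight'] -> still reported
print(unexpected_keys)  # [] -> silenced by the pooler pattern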
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
 
     config_class = AlbertConfig
     base_model_prefix = "albert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -851,7 +851,7 @@ class AlbertSOPHead(nn.Module):
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1021,7 +1021,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
 )
 class AlbertForTokenClassification(AlbertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1110,7 +1110,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
 )
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -843,7 +843,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
 )
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
 )
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -946,7 +946,7 @@ class BartModel(PretrainedBartModel):
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
 
     def __init__(self, config: BartConfig):
         super().__init__(config)
@@ -1020,10 +1020,10 @@ class TFBartModel(TFPretrainedBartModel):
 )
 class TFBartForConditionalGeneration(TFPretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
     ]
-    authorized_unexpected_keys = [
+    _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
     ]
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -969,8 +969,8 @@ class BertForPreTraining(BertPreTrainedModel):
 )
 class BertLMHeadModel(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1087,8 +1087,8 @@ class BertLMHeadModel(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1469,7 +1469,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
 )
 class BertForTokenClassification(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1560,7 +1560,7 @@ class BertForTokenClassification(BertPreTrainedModel):
 )
 class BertForQuestionAnswering(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -938,8 +938,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1023,8 +1023,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1416,8 +1416,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss):
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1502,8 +1502,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):
 
     config_class = BertGenerationConfig
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):
 
     config_class = DebertaConfig
     base_model_prefix = "deberta"
-    authorized_missing_keys = ["position_ids"]
+    _keys_to_ignore_on_load_missing = ["position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "ctx_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def init_weights(self):
         self.ctx_encoder.init_weights()
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def init_weights(self):
         self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
    base_model_prefix = "span_predictor"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def init_weights(self):
         self.span_predictor.encoder.init_weights()
@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
-    authorized_missing_keys = [r"position_ids"]
-    authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
@@ -1005,11 +1005,11 @@ class FSMTModel(PretrainedFSMTModel):
 )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -780,7 +780,7 @@ class GPT2Model(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1097,7 +1097,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
 
     config_class = LayoutLMConfig
     base_model_prefix = "layoutlm"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel):
 
     config_class = LongformerConfig
     base_model_prefix = "longformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1621,7 +1621,7 @@ class LongformerModel(LongformerPreTrainedModel):
 @add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING)
 class LongformerForMaskedLM(LongformerPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1718,7 +1718,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel):
 )
 class LongformerForSequenceClassification(LongformerPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1827,7 +1827,7 @@ class LongformerClassificationHead(nn.Module):
 )
 class LongformerForQuestionAnswering(LongformerPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1961,7 +1961,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel):
 )
 class LongformerForTokenClassification(LongformerPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1961,7 +1961,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel):
 )
 class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2048,7 +2048,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel
 )
 class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2199,7 +2199,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer):
 )
 class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -2443,7 +2443,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic
 )
 class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):
 
     """
     config_class = MarianConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)
 
 @add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
 class TFMarianMTModel(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
     ]
@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
     """
     model_type = "mbart"
     config_class = MBartConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
     load_tf_weights = load_tf_weights_in_mobilebert
     base_model_prefix = "mobilebert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1054,7 +1054,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
 class MobileBertForMaskedLM(MobileBertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1350,7 +1350,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
 )
 class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1545,7 +1545,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
 )
 class MobileBertForTokenClassification(MobileBertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1030,7 +1030,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel):
 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
 class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1297,7 +1297,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque
 )
 class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1529,7 +1529,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic
 )
 class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -42,12 +42,12 @@ class MT5Model(T5Model):
     """
     model_type = "mt5"
     config_class = MT5Config
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
     ]
@@ -71,13 +71,13 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration):
 
     model_type = "mt5"
     config_class = MT5Config
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"lm_head\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
     ]
@@ -279,7 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     config_class = OpenAIGPTConfig
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -46,14 +46,14 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
     """
     # All the code is in src/transformers/models/bart/modeling_bart.py
     config_class = PegasusConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"encoder\.version",
         r"decoder\.version",
         "model.encoder.embed_positions",
         "model.decoder.embed_positions",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
 
 @add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
 class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
@@ -216,7 +216,7 @@ class RagPreTrainedModel(PreTrainedModel):
     """
     config_class = RagConfig
     base_model_prefix = "rag"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     @classmethod
     def from_pretrained_question_encoder_generator(
@@ -576,7 +576,7 @@ class RobertaModel(RobertaPreTrainedModel):
 
     """
 
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
     def __init__(self, config, add_pooling_layer=True):
@@ -711,8 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
     """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
 )
 class RobertaForCausalLM(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -829,8 +829,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel):
 
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class RobertaForMaskedLM(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -948,7 +948,7 @@ class RobertaLMHead(nn.Module):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForSequenceClassification(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1031,7 +1031,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForMultipleChoice(RobertaPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1123,8 +1123,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForTokenClassification(RobertaPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1233,8 +1233,8 @@ class RobertaClassificationHead(nn.Module):
     ROBERTA_START_DOCSTRING,
 )
 class RobertaForQuestionAnswering(RobertaPreTrainedModel):
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -765,7 +765,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer):
 @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING)
 class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -877,7 +877,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer):
 )
 class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1084,7 +1084,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss)
 )
 class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1171,7 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific
 )
 class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -428,7 +428,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel):
 
     config_class = SqueezeBertConfig
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -642,7 +642,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
 @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING)
 class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
 
-    authorized_missing_keys = [r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1086,7 +1086,7 @@ T5_INPUTS_DOCSTRING = r"""
     T5_START_DOCSTRING,
 )
 class T5Model(T5PreTrainedModel):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
@@ -1258,7 +1258,7 @@ class T5Model(T5PreTrainedModel):
 
 @add_start_docstrings("""T5 Model with a `language modeling` head on top. """, T5_START_DOCSTRING)
 class T5ForConditionalGeneration(T5PreTrainedModel):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"encoder\.embed_tokens\.weight",
         r"decoder\.embed_tokens\.weight",
         r"lm_head\.weight",
@@ -399,7 +399,7 @@ XLM_INPUTS_DOCSTRING = r"""
     XLM_START_DOCSTRING,
 )
 class XLMModel(XLMPreTrainedModel):
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -540,7 +540,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel):
     config_class = {{cookiecutter.camelcase_modelname}}Config
     load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}}
     base_model_prefix = "{{cookiecutter.lowercase_modelname}}"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -135,17 +135,17 @@ class ModelTesterMixin:
             max_diff = np.amax(np.abs(out_1 - out_2))
             self.assertLessEqual(max_diff, 1e-5)
 
-    def test_save_load_keys_to_never_save(self):
+    def test_save_load__keys_to_ignore_on_save(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
 
         for model_class in self.all_model_classes:
             model = model_class(config)
-            keys_to_never_save = getattr(model, "keys_to_never_save", None)
-            if keys_to_never_save is None:
+            _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None)
+            if _keys_to_ignore_on_save is None:
                 continue
 
             # check the keys are in the original state_dict
-            for k in keys_to_never_save:
+            for k in _keys_to_ignore_on_save:
                 self.assertIn(k, model.state_dict())
 
             # check that certain keys didn't get saved with the model
@@ -153,7 +153,7 @@ class ModelTesterMixin:
                 model.save_pretrained(tmpdirname)
                 output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME)
                 state_dict_saved = torch.load(output_model_file)
-                for k in keys_to_never_save:
+                for k in _keys_to_ignore_on_save:
                     self.assertNotIn(k, state_dict_saved)
 
     def test_initialization(self):
@@ -60,7 +60,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MarianMTModel,) if is_torch_available() else ()
 
-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 
     def setUp(self):
         self.model_tester = ModelTester(self)
@@ -47,7 +47,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
 
-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 
     def setUp(self):
         self.model_tester = ModelTester(self)
@@ -43,7 +43,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
 
-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 
     def setUp(self):
         self.model_tester = ModelTester(self)