From e84786aaa69f9013ed596cf1d368d12999903005 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Mon, 23 Nov 2020 12:33:13 -0800
Subject: [PATCH] consistent ignore keys + make private (#8737)

* consistent ignore keys + make private

* style

* - authorized_missing_keys => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected

* move public doc of private attributes to private comment
---
 src/transformers/modeling_tf_pytorch_utils.py | 12 ++++----
 src/transformers/modeling_tf_utils.py | 20 ++++++-------
 src/transformers/modeling_utils.py | 29 ++++++++++---------
 .../models/albert/modeling_albert.py | 8 ++---
 .../models/albert/modeling_tf_albert.py | 6 ++--
 src/transformers/models/bart/modeling_bart.py | 2 +-
 .../models/bart/modeling_tf_bart.py | 4 +--
 src/transformers/models/bert/modeling_bert.py | 14 ++++-----
 .../models/bert/modeling_tf_bert.py | 16 +++++-----
 .../modeling_bert_generation.py | 2 +-
 .../models/deberta/modeling_deberta.py | 2 +-
 src/transformers/models/dpr/modeling_dpr.py | 6 ++--
 .../models/electra/modeling_electra.py | 4 +--
 src/transformers/models/fsmt/modeling_fsmt.py | 4 +--
 src/transformers/models/gpt2/modeling_gpt2.py | 4 +--
 .../models/layoutlm/modeling_layoutlm.py | 2 +-
 .../models/longformer/modeling_longformer.py | 10 +++---
 .../longformer/modeling_tf_longformer.py | 8 ++---
 .../models/marian/modeling_marian.py | 4 +--
 .../models/marian/modeling_tf_marian.py | 2 +-
 .../models/mbart/modeling_mbart.py | 4 +--
 .../models/mobilebert/modeling_mobilebert.py | 8 ++---
 .../mobilebert/modeling_tf_mobilebert.py | 6 ++--
 src/transformers/models/mt5/modeling_mt5.py | 8 ++---
 .../models/openai/modeling_openai.py | 2 +-
 .../models/pegasus/modeling_pegasus.py | 4 +--
 .../models/pegasus/modeling_tf_pegasus.py | 2 +-
 src/transformers/models/rag/modeling_rag.py | 2 +-
 .../models/roberta/modeling_roberta.py | 22 +++++++-------
 .../models/roberta/modeling_tf_roberta.py | 8 ++---
 .../squeezebert/modeling_squeezebert.py | 4 +--
 src/transformers/models/t5/modeling_t5.py | 4 +--
 src/transformers/models/xlm/modeling_xlm.py | 2 +-
 ...ng_{{cookiecutter.lowercase_modelname}}.py | 2 +-
 tests/test_modeling_common.py | 10 +++---
 tests/test_modeling_marian.py | 2 +-
 tests/test_modeling_mbart.py | 2 +-
 tests/test_modeling_pegasus.py | 2 +-
 38 files changed, 127 insertions(+), 126 deletions(-)

diff --git a/src/transformers/modeling_tf_pytorch_utils.py b/src/transformers/modeling_tf_pytorch_utils.py
index 5392d321576..761cf7d721e 100644
--- a/src/transformers/modeling_tf_pytorch_utils.py
+++ b/src/transformers/modeling_tf_pytorch_utils.py
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
-            elif tf_model.authorized_missing_keys is not None:
+            elif tf_model._keys_to_ignore_on_load_missing is not None:
                 # authorized missing keys don't have to be loaded
-                if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                     continue
 
             raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
 
     unexpected_keys = list(all_pytorch_weights)
 
-    if tf_model.authorized_missing_keys is not None:
-        for pat in tf_model.authorized_missing_keys:
+    if tf_model._keys_to_ignore_on_load_missing is not None:
+        for pat in tf_model._keys_to_ignore_on_load_missing:
             missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
-    if tf_model.authorized_unexpected_keys is not None:
-        for pat in tf_model.authorized_unexpected_keys:
+    if tf_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
     if len(unexpected_keys) > 0:
diff --git a/src/transformers/modeling_tf_utils.py b/src/transformers/modeling_tf_utils.py
index 6a6a6aa6c55..c8c8dfee482 100644
--- a/src/transformers/modeling_tf_utils.py
+++ b/src/transformers/modeling_tf_utils.py
@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
           :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
-          from the model when loading the model weights (and avoid unnecessary warnings).
-        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
-          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
+    # a list of re patterns of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re patterns of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
 
     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -742,12 +742,12 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
             model(model.dummy_inputs, training=False)  # Make sure restore ops are run
 
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
         if len(unexpected_keys) > 0:
diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index ead4f192eec..a5166aefc25 100755
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
-          when loading the model (and avoid unnecessary warnings).
-        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of of tensor names to ignore when saving the
-          model (useful for keys that aren't trained, but which are deterministic)
-
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
-    keys_to_never_save = None
+    # a list of re patterns of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re patterns of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of tensor names to ignore when saving the model (useful for keys that aren't
+    # trained, but which are deterministic)
+    _keys_to_ignore_on_save = None
 
     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +720,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             state_dict = model_to_save.state_dict()
 
         # Handle the case where some state_dict keys shouldn't be saved
-        if self.keys_to_never_save is not None:
-            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+        if self._keys_to_ignore_on_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}
 
         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -1034,12 +1035,12 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
 
         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
 
-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
         if len(unexpected_keys) > 0:
diff --git a/src/transformers/models/albert/modeling_albert.py b/src/transformers/models/albert/modeling_albert.py
index 140c122bad1..0989e4a3674 100755
--- a/src/transformers/models/albert/modeling_albert.py
+++ b/src/transformers/models/albert/modeling_albert.py
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):
 
     config_class = AlbertConfig
    base_model_prefix = "albert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -851,7 +851,7 @@ class AlbertSOPHead(nn.Module):
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1021,7 +1021,7 @@ class AlbertForSequenceClassification(AlbertPreTrainedModel):
 )
 class AlbertForTokenClassification(AlbertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1110,7 +1110,7 @@ class AlbertForTokenClassification(AlbertPreTrainedModel):
 )
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
diff --git a/src/transformers/models/albert/modeling_tf_albert.py b/src/transformers/models/albert/modeling_tf_albert.py
index ccbaab009cc..580b340dd57 100644
--- a/src/transformers/models/albert/modeling_tf_albert.py
+++ b/src/transformers/models/albert/modeling_tf_albert.py
@@ -843,7 +843,7 @@ class TFAlbertSOPHead(tf.keras.layers.Layer):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ class TFAlbertForSequenceClassification(TFAlbertPreTrainedModel, TFSequenceClass
 )
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificat
 )
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index df0090c28e1..afa006a5428 100644
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -946,7 +946,7 @@ class BartModel(PretrainedBartModel):
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
 
     def __init__(self, config: BartConfig):
         super().__init__(config)
diff --git a/src/transformers/models/bart/modeling_tf_bart.py b/src/transformers/models/bart/modeling_tf_bart.py
index b9c5b429754..e8bf4c7de7d 100644
--- a/src/transformers/models/bart/modeling_tf_bart.py
+++ b/src/transformers/models/bart/modeling_tf_bart.py
@@ -1020,10 +1020,10 @@ class TFBartModel(TFPretrainedBartModel):
 )
 class TFBartForConditionalGeneration(TFPretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
     ]
-    authorized_unexpected_keys = [
+    _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
     ]
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index a6bdf641553..6352a7be058 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -969,8 +969,8 @@ class BertForPreTraining(BertPreTrainedModel):
 )
 class BertLMHeadModel(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1087,8 +1087,8 @@ class BertLMHeadModel(BertPreTrainedModel):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1469,7 +1469,7 @@ class BertForMultipleChoice(BertPreTrainedModel):
 )
 class BertForTokenClassification(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1560,7 +1560,7 @@ class BertForTokenClassification(BertPreTrainedModel):
 )
 class BertForQuestionAnswering(BertPreTrainedModel):
 
-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
 
     def __init__(self, config):
         super().__init__(config)
diff --git a/src/transformers/models/bert/modeling_tf_bert.py b/src/transformers/models/bert/modeling_tf_bert.py
index f6b9d81d269..53e054f57c4 100644
--- a/src/transformers/models/bert/modeling_tf_bert.py
+++ b/src/transformers/models/bert/modeling_tf_bert.py
@@ -938,8 +938,8 @@ class TFBertForPreTraining(TFBertPreTrainedModel, TFBertPreTrainingLoss):
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
""", BERT_START_DOCSTRING) class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): - authorized_unexpected_keys = [r"pooler"] - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1023,8 +1023,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss): class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss): - authorized_unexpected_keys = [r"pooler"] - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1416,8 +1416,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel, TFMultipleChoiceLoss): ) class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss): - authorized_unexpected_keys = [r"pooler"] - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1502,8 +1502,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationL ) class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss): - authorized_unexpected_keys = [r"pooler"] - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) diff --git a/src/transformers/models/bert_generation/modeling_bert_generation.py b/src/transformers/models/bert_generation/modeling_bert_generation.py index 9ab4d1ee4de..7efe4422ad7 100755 --- a/src/transformers/models/bert_generation/modeling_bert_generation.py +++ b/src/transformers/models/bert_generation/modeling_bert_generation.py @@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel): config_class = BertGenerationConfig base_model_prefix = "bert" - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights """ diff --git a/src/transformers/models/deberta/modeling_deberta.py b/src/transformers/models/deberta/modeling_deberta.py index 00ae44aa432..47a6b18281f 100644 --- a/src/transformers/models/deberta/modeling_deberta.py +++ b/src/transformers/models/deberta/modeling_deberta.py @@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel): config_class = DebertaConfig base_model_prefix = "deberta" - authorized_missing_keys = ["position_ids"] + _keys_to_ignore_on_load_missing = ["position_ids"] def _init_weights(self, module): """ Initialize the weights """ diff --git a/src/transformers/models/dpr/modeling_dpr.py b/src/transformers/models/dpr/modeling_dpr.py index 5d5763137bb..2fbe4a06768 100644 --- a/src/transformers/models/dpr/modeling_dpr.py +++ b/src/transformers/models/dpr/modeling_dpr.py @@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel): config_class = DPRConfig load_tf_weights = None base_model_prefix = "ctx_encoder" - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def init_weights(self): self.ctx_encoder.init_weights() @@ -294,7 +294,7 @@ class 
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def init_weights(self):
         self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "span_predictor"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def init_weights(self):
         self.span_predictor.encoder.init_weights()
diff --git a/src/transformers/models/electra/modeling_electra.py b/src/transformers/models/electra/modeling_electra.py
index 3a4903cd26d..a2d3b5be141 100644
--- a/src/transformers/models/electra/modeling_electra.py
+++ b/src/transformers/models/electra/modeling_electra.py
@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
-    authorized_missing_keys = [r"position_ids"]
-    authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
 
     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):
diff --git a/src/transformers/models/fsmt/modeling_fsmt.py b/src/transformers/models/fsmt/modeling_fsmt.py
index 27b7737232e..0c9337e30a4 100644
--- a/src/transformers/models/fsmt/modeling_fsmt.py
+++ b/src/transformers/models/fsmt/modeling_fsmt.py
@@ -1005,11 +1005,11 @@ class FSMTModel(PretrainedFSMTModel):
 )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
diff --git a/src/transformers/models/gpt2/modeling_gpt2.py b/src/transformers/models/gpt2/modeling_gpt2.py
index 39b40a1e54b..12c9d143690 100644
--- a/src/transformers/models/gpt2/modeling_gpt2.py
+++ b/src/transformers/models/gpt2/modeling_gpt2.py
@@ -780,7 +780,7 @@ class GPT2Model(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
 
     def __init__(self, config):
         super().__init__(config)
@@ -1097,7 +1097,7 @@ class GPT2DoubleHeadsModel(GPT2PreTrainedModel):
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
 
     def __init__(self, config):
         super().__init__(config)
diff --git a/src/transformers/models/layoutlm/modeling_layoutlm.py b/src/transformers/models/layoutlm/modeling_layoutlm.py
index f75eb701008..3c48d436f14 100644
--- a/src/transformers/models/layoutlm/modeling_layoutlm.py
+++ b/src/transformers/models/layoutlm/modeling_layoutlm.py
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):
 
     config_class = LayoutLMConfig
     base_model_prefix = "layoutlm"
"layoutlm" - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights """ diff --git a/src/transformers/models/longformer/modeling_longformer.py b/src/transformers/models/longformer/modeling_longformer.py index 619db76f0e9..9e4a5a28fb8 100755 --- a/src/transformers/models/longformer/modeling_longformer.py +++ b/src/transformers/models/longformer/modeling_longformer.py @@ -1303,7 +1303,7 @@ class LongformerPreTrainedModel(PreTrainedModel): config_class = LongformerConfig base_model_prefix = "longformer" - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights """ @@ -1621,7 +1621,7 @@ class LongformerModel(LongformerPreTrainedModel): @add_start_docstrings("""Longformer Model with a `language modeling` head on top. """, LONGFORMER_START_DOCSTRING) class LongformerForMaskedLM(LongformerPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -1718,7 +1718,7 @@ class LongformerForMaskedLM(LongformerPreTrainedModel): ) class LongformerForSequenceClassification(LongformerPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -1827,7 +1827,7 @@ class LongformerClassificationHead(nn.Module): ) class LongformerForQuestionAnswering(LongformerPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -1961,7 +1961,7 @@ class LongformerForQuestionAnswering(LongformerPreTrainedModel): ) class LongformerForTokenClassification(LongformerPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/longformer/modeling_tf_longformer.py b/src/transformers/models/longformer/modeling_tf_longformer.py index 6fd30dfd796..e2057ed8fb4 100644 --- a/src/transformers/models/longformer/modeling_tf_longformer.py +++ b/src/transformers/models/longformer/modeling_tf_longformer.py @@ -1961,7 +1961,7 @@ class TFLongformerModel(TFLongformerPreTrainedModel): ) class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModelingLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -2048,7 +2048,7 @@ class TFLongformerForMaskedLM(TFLongformerPreTrainedModel, TFMaskedLanguageModel ) class TFLongformerForQuestionAnswering(TFLongformerPreTrainedModel, TFQuestionAnsweringLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -2199,7 +2199,7 @@ class TFLongformerClassificationHead(tf.keras.layers.Layer): ) class TFLongformerForSequenceClassification(TFLongformerPreTrainedModel, TFSequenceClassificationLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -2443,7 +2443,7 @@ class TFLongformerForMultipleChoice(TFLongformerPreTrainedModel, TFMultipleChoic ) 
 class TFLongformerForTokenClassification(TFLongformerPreTrainedModel, TFTokenClassificationLoss):
 
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]
 
     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
diff --git a/src/transformers/models/marian/modeling_marian.py b/src/transformers/models/marian/modeling_marian.py
index 637529c1168..25d3dc1ea96 100644
--- a/src/transformers/models/marian/modeling_marian.py
+++ b/src/transformers/models/marian/modeling_marian.py
@@ -47,11 +47,11 @@ class MarianMTModel(BartForConditionalGeneration):
     """
 
     config_class = MarianConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
diff --git a/src/transformers/models/marian/modeling_tf_marian.py b/src/transformers/models/marian/modeling_tf_marian.py
index e385e5f6e5e..f17182306ee 100644
--- a/src/transformers/models/marian/modeling_tf_marian.py
+++ b/src/transformers/models/marian/modeling_tf_marian.py
@@ -37,7 +37,7 @@ logger = logging.get_logger(__name__)
 
 @add_start_docstrings("Marian model for machine translation", START_DOCSTRING)
 class TFMarianMTModel(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
     ]
diff --git a/src/transformers/models/mbart/modeling_mbart.py b/src/transformers/models/mbart/modeling_mbart.py
index 2978a250dcb..de19f46e3a9 100644
--- a/src/transformers/models/mbart/modeling_mbart.py
+++ b/src/transformers/models/mbart/modeling_mbart.py
@@ -29,11 +29,11 @@ class MBartForConditionalGeneration(BartForConditionalGeneration):
     """
 
     model_type = "mbart"
     config_class = MBartConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
diff --git a/src/transformers/models/mobilebert/modeling_mobilebert.py b/src/transformers/models/mobilebert/modeling_mobilebert.py
index 3628f80871d..5bb5c6353ac 100644
--- a/src/transformers/models/mobilebert/modeling_mobilebert.py
+++ b/src/transformers/models/mobilebert/modeling_mobilebert.py
@@ -677,7 +677,7 @@ class MobileBertPreTrainedModel(PreTrainedModel):
     pretrained_model_archive_map = MOBILEBERT_PRETRAINED_MODEL_ARCHIVE_LIST
     load_tf_weights = load_tf_weights_in_mobilebert
     base_model_prefix = "mobilebert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """ Initialize the weights """
@@ -1054,7 +1054,7 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
 @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING)
""", MOBILEBERT_START_DOCSTRING) class MobileBertForMaskedLM(MobileBertPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -1350,7 +1350,7 @@ class MobileBertForSequenceClassification(MobileBertPreTrainedModel): ) class MobileBertForQuestionAnswering(MobileBertPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -1545,7 +1545,7 @@ class MobileBertForMultipleChoice(MobileBertPreTrainedModel): ) class MobileBertForTokenClassification(MobileBertPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py index a776230f276..fa75fe9d4e7 100644 --- a/src/transformers/models/mobilebert/modeling_tf_mobilebert.py +++ b/src/transformers/models/mobilebert/modeling_tf_mobilebert.py @@ -1030,7 +1030,7 @@ class TFMobileBertForPreTraining(TFMobileBertPreTrainedModel): @add_start_docstrings("""MobileBert Model with a `language modeling` head on top. """, MOBILEBERT_START_DOCSTRING) class TFMobileBertForMaskedLM(TFMobileBertPreTrainedModel, TFMaskedLanguageModelingLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1297,7 +1297,7 @@ class TFMobileBertForSequenceClassification(TFMobileBertPreTrainedModel, TFSeque ) class TFMobileBertForQuestionAnswering(TFMobileBertPreTrainedModel, TFQuestionAnsweringLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1529,7 +1529,7 @@ class TFMobileBertForMultipleChoice(TFMobileBertPreTrainedModel, TFMultipleChoic ) class TFMobileBertForTokenClassification(TFMobileBertPreTrainedModel, TFTokenClassificationLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) diff --git a/src/transformers/models/mt5/modeling_mt5.py b/src/transformers/models/mt5/modeling_mt5.py index 10d64faf305..3cffaff4254 100644 --- a/src/transformers/models/mt5/modeling_mt5.py +++ b/src/transformers/models/mt5/modeling_mt5.py @@ -42,12 +42,12 @@ class MT5Model(T5Model): """ model_type = "mt5" config_class = MT5Config - authorized_missing_keys = [ + _keys_to_ignore_on_load_missing = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", ] - keys_to_never_save = [ + _keys_to_ignore_on_save = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", ] @@ -71,13 +71,13 @@ class MT5ForConditionalGeneration(T5ForConditionalGeneration): model_type = "mt5" config_class = MT5Config - authorized_missing_keys = [ + _keys_to_ignore_on_load_missing = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", ] - keys_to_never_save = [ + _keys_to_ignore_on_save = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", ] 
diff --git a/src/transformers/models/openai/modeling_openai.py b/src/transformers/models/openai/modeling_openai.py
index 18f0a1f687c..3d8df216291 100644
--- a/src/transformers/models/openai/modeling_openai.py
+++ b/src/transformers/models/openai/modeling_openai.py
@@ -279,7 +279,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     config_class = OpenAIGPTConfig
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     def _init_weights(self, module):
         """Initialize the weights."""
diff --git a/src/transformers/models/pegasus/modeling_pegasus.py b/src/transformers/models/pegasus/modeling_pegasus.py
index 64515c7a8ba..3e623a77040 100644
--- a/src/transformers/models/pegasus/modeling_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_pegasus.py
@@ -46,14 +46,14 @@ class PegasusForConditionalGeneration(BartForConditionalGeneration):
     """
     # All the code is in src/transformers/models/bart/modeling_bart.py
     config_class = PegasusConfig
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"encoder\.version",
         r"decoder\.version",
         "model.encoder.embed_positions",
         "model.decoder.embed_positions",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
diff --git a/src/transformers/models/pegasus/modeling_tf_pegasus.py b/src/transformers/models/pegasus/modeling_tf_pegasus.py
index 7f53dba8e00..bec856575d1 100644
--- a/src/transformers/models/pegasus/modeling_tf_pegasus.py
+++ b/src/transformers/models/pegasus/modeling_tf_pegasus.py
@@ -32,7 +32,7 @@ logger = logging.get_logger(__name__)
 
 @add_start_docstrings("Pegasus model for summarization", START_DOCSTRING)
 class TFPegasusForConditionalGeneration(TFBartForConditionalGeneration):
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
         r"model.encoder.embed_positions.weight",
         r"model.decoder.embed_positions.weight",
diff --git a/src/transformers/models/rag/modeling_rag.py b/src/transformers/models/rag/modeling_rag.py
index f96868cb543..1c996a01189 100644
--- a/src/transformers/models/rag/modeling_rag.py
+++ b/src/transformers/models/rag/modeling_rag.py
@@ -216,7 +216,7 @@ class RagPreTrainedModel(PreTrainedModel):
     """
     config_class = RagConfig
     base_model_prefix = "rag"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     @classmethod
     def from_pretrained_question_encoder_generator(
diff --git a/src/transformers/models/roberta/modeling_roberta.py b/src/transformers/models/roberta/modeling_roberta.py
index d322bbb0918..072b0bd8da6 100644
--- a/src/transformers/models/roberta/modeling_roberta.py
+++ b/src/transformers/models/roberta/modeling_roberta.py
@@ -576,7 +576,7 @@ class RobertaModel(RobertaPreTrainedModel):
 
     """
 
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
 
     # Copied from transformers.models.bert.modeling_bert.BertModel.__init__ with Bert->Roberta
     def __init__(self, config, add_pooling_layer=True):
@@ -711,8 +711,8 @@ class RobertaModel(RobertaPreTrainedModel):
     """RoBERTa Model with a `language modeling` head on top for CLM fine-tuning. """, ROBERTA_START_DOCSTRING
""", ROBERTA_START_DOCSTRING ) class RobertaForCausalLM(RobertaPreTrainedModel): - authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -829,8 +829,8 @@ class RobertaForCausalLM(RobertaPreTrainedModel): @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. """, ROBERTA_START_DOCSTRING) class RobertaForMaskedLM(RobertaPreTrainedModel): - authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"] - authorized_unexpected_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] def __init__(self, config): super().__init__(config) @@ -948,7 +948,7 @@ class RobertaLMHead(nn.Module): ROBERTA_START_DOCSTRING, ) class RobertaForSequenceClassification(RobertaPreTrainedModel): - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) @@ -1031,7 +1031,7 @@ class RobertaForSequenceClassification(RobertaPreTrainedModel): ROBERTA_START_DOCSTRING, ) class RobertaForMultipleChoice(RobertaPreTrainedModel): - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) @@ -1123,8 +1123,8 @@ class RobertaForMultipleChoice(RobertaPreTrainedModel): ROBERTA_START_DOCSTRING, ) class RobertaForTokenClassification(RobertaPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) @@ -1233,8 +1233,8 @@ class RobertaClassificationHead(nn.Module): ROBERTA_START_DOCSTRING, ) class RobertaForQuestionAnswering(RobertaPreTrainedModel): - authorized_unexpected_keys = [r"pooler"] - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_unexpected = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/roberta/modeling_tf_roberta.py b/src/transformers/models/roberta/modeling_tf_roberta.py index 9bfb2954d0f..9a4d8828c02 100644 --- a/src/transformers/models/roberta/modeling_tf_roberta.py +++ b/src/transformers/models/roberta/modeling_tf_roberta.py @@ -765,7 +765,7 @@ class TFRobertaLMHead(tf.keras.layers.Layer): @add_start_docstrings("""RoBERTa Model with a `language modeling` head on top. 
""", ROBERTA_START_DOCSTRING) class TFRobertaForMaskedLM(TFRobertaPreTrainedModel, TFMaskedLanguageModelingLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -877,7 +877,7 @@ class TFRobertaClassificationHead(tf.keras.layers.Layer): ) class TFRobertaForSequenceClassification(TFRobertaPreTrainedModel, TFSequenceClassificationLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1084,7 +1084,7 @@ class TFRobertaForMultipleChoice(TFRobertaPreTrainedModel, TFMultipleChoiceLoss) ) class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassificationLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) @@ -1171,7 +1171,7 @@ class TFRobertaForTokenClassification(TFRobertaPreTrainedModel, TFTokenClassific ) class TFRobertaForQuestionAnswering(TFRobertaPreTrainedModel, TFQuestionAnsweringLoss): - authorized_missing_keys = [r"pooler"] + _keys_to_ignore_on_load_missing = [r"pooler"] def __init__(self, config, *inputs, **kwargs): super().__init__(config, *inputs, **kwargs) diff --git a/src/transformers/models/squeezebert/modeling_squeezebert.py b/src/transformers/models/squeezebert/modeling_squeezebert.py index ba61c3e70f7..56a40d143e9 100644 --- a/src/transformers/models/squeezebert/modeling_squeezebert.py +++ b/src/transformers/models/squeezebert/modeling_squeezebert.py @@ -428,7 +428,7 @@ class SqueezeBertPreTrainedModel(PreTrainedModel): config_class = SqueezeBertConfig base_model_prefix = "transformer" - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights """ @@ -642,7 +642,7 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel): @add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top. """, SQUEEZEBERT_START_DOCSTRING) class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel): - authorized_missing_keys = [r"predictions.decoder.bias"] + _keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"] def __init__(self, config): super().__init__(config) diff --git a/src/transformers/models/t5/modeling_t5.py b/src/transformers/models/t5/modeling_t5.py index 7b28d2590cc..adba0b79fce 100644 --- a/src/transformers/models/t5/modeling_t5.py +++ b/src/transformers/models/t5/modeling_t5.py @@ -1086,7 +1086,7 @@ T5_INPUTS_DOCSTRING = r""" T5_START_DOCSTRING, ) class T5Model(T5PreTrainedModel): - authorized_missing_keys = [ + _keys_to_ignore_on_load_missing = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", @@ -1258,7 +1258,7 @@ class T5Model(T5PreTrainedModel): @add_start_docstrings("""T5 Model with a `language modeling` head on top. 
""", T5_START_DOCSTRING) class T5ForConditionalGeneration(T5PreTrainedModel): - authorized_missing_keys = [ + _keys_to_ignore_on_load_missing = [ r"encoder\.embed_tokens\.weight", r"decoder\.embed_tokens\.weight", r"lm_head\.weight", diff --git a/src/transformers/models/xlm/modeling_xlm.py b/src/transformers/models/xlm/modeling_xlm.py index a144d58c735..b0667edab76 100755 --- a/src/transformers/models/xlm/modeling_xlm.py +++ b/src/transformers/models/xlm/modeling_xlm.py @@ -399,7 +399,7 @@ XLM_INPUTS_DOCSTRING = r""" XLM_START_DOCSTRING, ) class XLMModel(XLMPreTrainedModel): - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def __init__(self, config): super().__init__(config) diff --git a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py index 6036f8bc4e6..de898390fcf 100755 --- a/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py +++ b/templates/adding_a_new_model/cookiecutter-template-{{cookiecutter.modelname}}/modeling_{{cookiecutter.lowercase_modelname}}.py @@ -540,7 +540,7 @@ class {{cookiecutter.camelcase_modelname}}PreTrainedModel(PreTrainedModel): config_class = {{cookiecutter.camelcase_modelname}}Config load_tf_weights = load_tf_weights_in_{{cookiecutter.lowercase_modelname}} base_model_prefix = "{{cookiecutter.lowercase_modelname}}" - authorized_missing_keys = [r"position_ids"] + _keys_to_ignore_on_load_missing = [r"position_ids"] def _init_weights(self, module): """ Initialize the weights """ diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index f4cc7d6958e..6740761cf2d 100755 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -135,17 +135,17 @@ class ModelTesterMixin: max_diff = np.amax(np.abs(out_1 - out_2)) self.assertLessEqual(max_diff, 1e-5) - def test_save_load_keys_to_never_save(self): + def test_save_load__keys_to_ignore_on_save(self): config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common() for model_class in self.all_model_classes: model = model_class(config) - keys_to_never_save = getattr(model, "keys_to_never_save", None) - if keys_to_never_save is None: + _keys_to_ignore_on_save = getattr(model, "_keys_to_ignore_on_save", None) + if _keys_to_ignore_on_save is None: continue # check the keys are in the original state_dict - for k in keys_to_never_save: + for k in _keys_to_ignore_on_save: self.assertIn(k, model.state_dict()) # check that certain keys didn't get saved with the model @@ -153,7 +153,7 @@ class ModelTesterMixin: model.save_pretrained(tmpdirname) output_model_file = os.path.join(tmpdirname, WEIGHTS_NAME) state_dict_saved = torch.load(output_model_file) - for k in keys_to_never_save: + for k in _keys_to_ignore_on_save: self.assertNotIn(k, state_dict_saved) def test_initialization(self): diff --git a/tests/test_modeling_marian.py b/tests/test_modeling_marian.py index dc50daa9a78..3fc3338fec6 100644 --- a/tests/test_modeling_marian.py +++ b/tests/test_modeling_marian.py @@ -60,7 +60,7 @@ class ModelTester: class SelectiveCommonTest(unittest.TestCase): all_model_classes = (MarianMTModel,) if is_torch_available() else () - test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save + test_save_load__keys_to_ignore_on_save = 
 
     def setUp(self):
         self.model_tester = ModelTester(self)
diff --git a/tests/test_modeling_mbart.py b/tests/test_modeling_mbart.py
index 8bb874613e9..c394e62f046 100644
--- a/tests/test_modeling_mbart.py
+++ b/tests/test_modeling_mbart.py
@@ -47,7 +47,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (MBartForConditionalGeneration,) if is_torch_available() else ()
-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 
     def setUp(self):
         self.model_tester = ModelTester(self)
diff --git a/tests/test_modeling_pegasus.py b/tests/test_modeling_pegasus.py
index 61435270119..96b2d403dcf 100644
--- a/tests/test_modeling_pegasus.py
+++ b/tests/test_modeling_pegasus.py
@@ -43,7 +43,7 @@ class ModelTester:
 class SelectiveCommonTest(unittest.TestCase):
     all_model_classes = (PegasusForConditionalGeneration,) if is_torch_available() else ()
-    test_save_load_keys_to_never_save = ModelTesterMixin.test_save_load_keys_to_never_save
+    test_save_load__keys_to_ignore_on_save = ModelTesterMixin.test_save_load__keys_to_ignore_on_save
 
     def setUp(self):
         self.model_tester = ModelTester(self)
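
Usage note (appended for reference; not part of the patch above): the renamed attributes are class-level overrides that model subclasses declare so that from_pretrained() suppresses warnings for expected key mismatches and save_pretrained() skips selected entries. Below is a minimal sketch; the MyBertWithHead class, the head layer, and the key patterns are hypothetical, and only the attribute names and their load/save semantics come from this patch.

    # Hypothetical subclass; only the attribute names below come from the patch.
    import torch.nn as nn
    from transformers import BertConfig, BertModel, BertPreTrainedModel

    class MyBertWithHead(BertPreTrainedModel):
        # regex patterns of weights that may be missing from a checkpoint (no "missing keys" warning on load)
        _keys_to_ignore_on_load_missing = [r"position_ids", r"my_head\.bias"]
        # regex patterns of checkpoint weights this architecture never uses (no "unexpected keys" warning on load)
        _keys_to_ignore_on_load_unexpected = [r"pooler"]
        # exact state_dict keys that save_pretrained() should not write out
        _keys_to_ignore_on_save = ["my_head.bias"]

        def __init__(self, config: BertConfig):
            super().__init__(config)
            self.bert = BertModel(config)
            self.my_head = nn.Linear(config.hidden_size, config.vocab_size)
            self.init_weights()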