Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
Attempting to test automatically the `_keys_to_ignore`. (#20042)

* Attempting to test automatically the `_keys_to_ignore`.
* Style.
* First fix pass.
* Moving test on its own.
* Another batch.
* Second round removing BatchNorm
* Fixing layoutlmv{2,3} + support older Python.
* Disable miss missing warning.
* Removing dodgy additions.
* Big pass.
* mbart.
* More corrections.
* Fixup.
* Updating test_correct_missing_keys
* Add escape hatch for when the head has no extra params so doesn't need the missing keys check.
* Fixing test.
* Greener.
* Green ! (except for weird splinter bug).
* Adding a test about `named_parameters` usage.
* Shorten message.
* Apply suggestions from code review
  Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
* After rebase modifications.
* More explicit condition checking.
* Fixing slow tests issues.
* Remove extra pdb.
* Remove print.
* Attempt to make failure consistent + fixing roc_bert.
* Removing the seed (all tests passing with it).

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This commit is contained in:
parent d606d566ab
commit bac2d29a80
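The `_keys_to_ignore_*` attributes touched throughout this diff feed the loading report produced by `from_pretrained`. A small way to inspect that report, using the public API the tests below rely on (model and checkpoint names are just an example; the exact keys printed depend on the installed version):

from transformers import BertForPreTraining

# output_loading_info=True also returns which checkpoint keys were reported as
# missing or unexpected, after the _keys_to_ignore_* filters have been applied.
model, loading_info = BertForPreTraining.from_pretrained(
    "bert-base-uncased", output_loading_info=True
)
print(loading_info["missing_keys"])     # e.g. ["cls.predictions.decoder.bias"]
print(loading_info["unexpected_keys"])  # checkpoint keys the model did not use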
@@ -2421,8 +2421,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         add_prefix_to_model = has_prefix_module and not expects_prefix_module

         if remove_prefix_from_model:
-            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(prefix)]
-            expected_keys = [".".join(s.split(".")[1:]) if s.startswith(prefix) else s for s in expected_keys]
+            _prefix = f"{prefix}."
+            expected_keys_not_prefixed = [s for s in expected_keys if not s.startswith(_prefix)]
+            expected_keys = [s[len(_prefix) :] if s.startswith(_prefix) else s for s in expected_keys]
         elif add_prefix_to_model:
             expected_keys = [".".join([prefix, s]) for s in expected_keys]

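One effect of switching from `startswith(prefix)` to `startswith(f"{prefix}.")` is that only keys that really live under the base-model sub-module get rewritten; a key that merely begins with the same characters is left alone. A tiny illustration with made-up key names (not taken from any real model):

prefix = "bert"
expected_keys = [
    "bert.encoder.layer.0.attention.self.query.weight",
    "bert_extra.dense.weight",  # not inside the `bert` sub-module, just a similar name
]

# Old check: both keys "start with" the prefix string.
print([k for k in expected_keys if k.startswith(prefix)])

# New check: the trailing dot makes the match a module-path boundary.
_prefix = f"{prefix}."
print([k[len(_prefix):] for k in expected_keys if k.startswith(_prefix)])
# -> ['encoder.layer.0.attention.self.query.weight']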
@@ -2641,13 +2642,16 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix

         # torch.nn.ParameterList is a special case where two parameter keywords
         # are appended to the module name, *e.g.* bert.special_embeddings.0
-        module_keys = module_keys.union(set([".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit()]))
+        module_keys = module_keys.union(
+            set([".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit()])
+        )

         retrieved_modules = []
         # retrieve all modules that has at least one missing weight name
         for name, module in self.named_modules():
             if remove_prefix:
-                name = ".".join(name.split(".")[1:]) if name.startswith(self.base_model_prefix) else name
+                _prefix = f"{self.base_model_prefix}."
+                name = name[len(_prefix) :] if name.startswith(_prefix) else name
             elif add_prefix:
                 name = ".".join([self.base_model_prefix, name]) if len(name) > 0 else self.base_model_prefix

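The added `len(key) > 0` test simply guards the `key[-1]` lookup: an empty name in `names` would otherwise raise an IndexError. A minimal reproduction (the names here are invented):

names = ["", "bert.special_embeddings.0", "classifier.weight"]

# Without the guard, the empty name crashes on key[-1]:
# set(".".join(key.split(".")[:-2]) for key in names if key[-1].isdigit())  # IndexError

# Guarded version, matching the patched line:
module_keys = set(".".join(key.split(".")[:-2]) for key in names if len(key) > 0 and key[-1].isdigit())
print(module_keys)  # {'bert'}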
@@ -762,6 +762,12 @@ class AlbertModel(AlbertPreTrainedModel):
     ALBERT_START_DOCSTRING,
 )
 class AlbertForPreTraining(AlbertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [
+        "predictions.decoder.weight",
+        "predictions.decoder.bias",
+        "embeddings.position_ids",
+    ]
+
     def __init__(self, config: AlbertConfig):
         super().__init__(config)

@@ -910,6 +916,11 @@ class AlbertSOPHead(nn.Module):
 class AlbertForMaskedLM(AlbertPreTrainedModel):

     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [
+        "predictions.decoder.weight",
+        "predictions.decoder.bias",
+        "embeddings.position_ids",
+    ]

     def __init__(self, config):
         super().__init__(config)

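A detail worth noting about entries like "embeddings.position_ids": in this version of the code they are registered buffers rather than trainable parameters, which is why the test-side changes later in this diff collect `named_buffers()` in addition to `named_parameters()`. A quick check, instantiating from a small config to avoid a download (the printed values are the expected result, not a captured run):

from transformers import AlbertConfig, AlbertForMaskedLM

model = AlbertForMaskedLM(
    AlbertConfig(hidden_size=64, num_hidden_layers=1, num_attention_heads=4, intermediate_size=128)
)

buffer_names = {name for name, _ in model.named_buffers()}
param_names = {name for name, _ in model.named_parameters()}

print("albert.embeddings.position_ids" in buffer_names)  # True: it is a buffer
print("albert.embeddings.position_ids" in param_names)   # False: not a parameter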
@@ -1153,6 +1153,8 @@ class BartDecoder(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartModel(BartPretrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: BartConfig):
         super().__init__(config)

@@ -1281,7 +1283,12 @@ class BartModel(BartPretrainedModel):
 )
 class BartForConditionalGeneration(BartPretrainedModel):
     base_model_prefix = "model"
-    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
+    _keys_to_ignore_on_load_missing = [
+        r"final_logits_bias",
+        r"lm_head.weight",
+        "encoder.embed_tokens.weight",
+        "decoder.embed_tokens.weight",
+    ]

     def __init__(self, config: BartConfig):
         super().__init__(config)
@@ -1451,6 +1458,8 @@ class BartForConditionalGeneration(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartForSequenceClassification(BartPretrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config: BartConfig, **kwargs):
         super().__init__(config, **kwargs)
         self.model = BartModel(config)
@@ -1578,6 +1587,8 @@ class BartForSequenceClassification(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartForQuestionAnswering(BartPretrainedModel):
+    _keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
+
     def __init__(self, config):
         super().__init__(config)

@@ -1714,6 +1725,8 @@ class BartDecoderWrapper(BartPretrainedModel):
     BART_START_DOCSTRING,
 )
 class BartForCausalLM(BartPretrainedModel):
+    _keys_to_ignore_on_load_missing = ["lm_head.weight"]
+
     def __init__(self, config):
         config = copy.deepcopy(config)
         config.is_decoder = True

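For the seq2seq models in this diff, `encoder.embed_tokens.weight` and `decoder.embed_tokens.weight` are tied to the shared input embedding, so a checkpoint may legitimately not carry them as separate tensors; listing them under `_keys_to_ignore_on_load_missing` keeps `from_pretrained` from warning about weights that are recovered through the tie. The tie itself is easy to verify in this version of the modeling code (any BART checkpoint should behave the same way):

from transformers import BartModel

model = BartModel.from_pretrained("facebook/bart-base")

# Both sub-module embedding tables are the very same tensor object as model.shared.
print(model.encoder.embed_tokens.weight is model.shared.weight)  # True
print(model.decoder.embed_tokens.weight is model.shared.weight)  # True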
@@ -1047,6 +1047,8 @@ class BertModel(BertPreTrainedModel):
     BERT_START_DOCSTRING,
 )
 class BertForPreTraining(BertPreTrainedModel):
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]
+
     def __init__(self, config):
         super().__init__(config)

@@ -1153,7 +1155,7 @@ class BertForPreTraining(BertPreTrainedModel):
 class BertLMHeadModel(BertPreTrainedModel):

     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1288,7 +1290,7 @@ class BertLMHeadModel(BertPreTrainedModel):
 class BertForMaskedLM(BertPreTrainedModel):

     _keys_to_ignore_on_load_unexpected = [r"pooler"]
-    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", r"cls.predictions.decoder.weight"]

     def __init__(self, config):
         super().__init__(config)

@ -855,6 +855,8 @@ class BertGenerationOnlyLMHead(nn.Module):
|
||||
BERT_GENERATION_START_DOCSTRING,
|
||||
)
|
||||
class BertGenerationDecoder(BertGenerationPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.decoder.weight", "lm_head.decoder.bias", "embeddings.position_ids"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -2262,6 +2262,8 @@ class BigBirdModel(BigBirdPreTrainedModel):
|
||||
|
||||
|
||||
class BigBirdForPreTraining(BigBirdPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2366,6 +2368,8 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""BigBird Model with a `language modeling` head on top.""", BIG_BIRD_START_DOCSTRING)
|
||||
class BigBirdForMaskedLM(BigBirdPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2508,8 +2512,12 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):
|
||||
"""BigBird Model with a `language modeling` head on top for CLM fine-tuning.""", BIG_BIRD_START_DOCSTRING
|
||||
)
|
||||
class BigBirdForCausalLM(BigBirdPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"position_ids",
|
||||
r"predictions.decoder.bias",
|
||||
"cls.predictions.decoder.weight",
|
||||
"cls.predictions.decoder.bias",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -2350,6 +2350,8 @@ class BigBirdPegasusDecoder(BigBirdPegasusPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.bart.modeling_bart.BartModel with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
|
||||
class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2480,7 +2482,12 @@ class BigBirdPegasusModel(BigBirdPegasusPreTrainedModel):
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForConditionalGeneration with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
|
||||
class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = [r"final_logits_bias", r"lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"final_logits_bias",
|
||||
r"lm_head.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig):
|
||||
super().__init__(config)
|
||||
@ -2651,6 +2658,8 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForSequenceClassification with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
|
||||
class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BigBirdPegasusConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = BigBirdPegasusModel(config)
|
||||
@ -2779,6 +2788,8 @@ class BigBirdPegasusForSequenceClassification(BigBirdPegasusPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForQuestionAnswering with Bart->BigBirdPegasus, BART->BIGBIRD_PEGASUS
|
||||
class BigBirdPegasusForQuestionAnswering(BigBirdPegasusPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2910,6 +2921,8 @@ class BigBirdPegasusDecoderWrapper(BigBirdPegasusPreTrainedModel):
|
||||
|
||||
|
||||
class BigBirdPegasusForCausalLM(BigBirdPegasusPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -1087,6 +1087,8 @@ class BlenderbotDecoder(BlenderbotPreTrainedModel):
|
||||
BLENDERBOT_START_DOCSTRING,
|
||||
)
|
||||
class BlenderbotModel(BlenderbotPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1231,6 +1233,8 @@ class BlenderbotForConditionalGeneration(BlenderbotPreTrainedModel):
|
||||
r"encoder.version",
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: BlenderbotConfig):
|
||||
@ -1420,6 +1424,8 @@ class BlenderbotDecoderWrapper(BlenderbotPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Blenderbot, facebook/bart-base->facebook/blenderbot-400M-distill
|
||||
class BlenderbotForCausalLM(BlenderbotPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -1081,6 +1081,8 @@ class BlenderbotSmallDecoder(BlenderbotSmallPreTrainedModel):
|
||||
BLENDERBOT_SMALL_START_DOCSTRING,
|
||||
)
|
||||
class BlenderbotSmallModel(BlenderbotSmallPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1213,6 +1215,8 @@ class BlenderbotSmallForConditionalGeneration(BlenderbotSmallPreTrainedModel):
|
||||
r"encoder.version",
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: BlenderbotSmallConfig):
|
||||
@ -1387,6 +1391,8 @@ class BlenderbotSmallDecoderWrapper(BlenderbotSmallPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->BlenderbotSmall, facebook/bart-base->facebook/blenderbot_small-90M
|
||||
class BlenderbotSmallForCausalLM(BlenderbotSmallPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -763,6 +763,8 @@ CONVBERT_INPUTS_DOCSTRING = r"""
|
||||
CONVBERT_START_DOCSTRING,
|
||||
)
|
||||
class ConvBertModel(ConvBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.embeddings = ConvBertEmbeddings(config)
|
||||
@ -877,6 +879,8 @@ class ConvBertGeneratorPredictions(nn.Module):
|
||||
|
||||
@add_start_docstrings("""ConvBERT Model with a `language modeling` head on top.""", CONVBERT_START_DOCSTRING)
|
||||
class ConvBertForMaskedLM(ConvBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["embeddings.position_ids", "generator.lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -987,6 +991,8 @@ class ConvBertClassificationHead(nn.Module):
|
||||
CONVBERT_START_DOCSTRING,
|
||||
)
|
||||
class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@ -1083,6 +1089,8 @@ class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
|
||||
CONVBERT_START_DOCSTRING,
|
||||
)
|
||||
class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1177,6 +1185,8 @@ class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
|
||||
CONVBERT_START_DOCSTRING,
|
||||
)
|
||||
class ConvBertForTokenClassification(ConvBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.num_labels = config.num_labels
|
||||
@ -1259,6 +1269,8 @@ class ConvBertForTokenClassification(ConvBertPreTrainedModel):
|
||||
CONVBERT_START_DOCSTRING,
|
||||
)
|
||||
class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["embeddings.position_ids"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -509,6 +509,8 @@ class CTRLModel(CTRLPreTrainedModel):
|
||||
CTRL_START_DOCSTRING,
|
||||
)
|
||||
class CTRLLMHeadModel(CTRLPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = CTRLModel(config)
|
||||
|
@ -1038,7 +1038,7 @@ class DebertaModel(DebertaPreTrainedModel):
|
||||
@add_start_docstrings("""DeBERTa Model with a `language modeling` head on top.""", DEBERTA_START_DOCSTRING)
|
||||
class DebertaForMaskedLM(DebertaPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -1139,7 +1139,7 @@ class DebertaV2Model(DebertaV2PreTrainedModel):
|
||||
# Copied from transformers.models.deberta.modeling_deberta.DebertaForMaskedLM with Deberta->DebertaV2
|
||||
class DebertaV2ForMaskedLM(DebertaV2PreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -1788,6 +1788,9 @@ class DeformableDetrModel(DeformableDetrPreTrainedModel):
|
||||
DEFORMABLE_DETR_START_DOCSTRING,
|
||||
)
|
||||
class DeformableDetrForObjectDetection(DeformableDetrPreTrainedModel):
|
||||
# When using clones, all layers > 0 will be clones, but layer 0 *is* required
|
||||
_keys_to_ignore_on_load_missing = ["bbox_embed\.[1-9]\d*", "class_embed\.[1-9]\d*"]
|
||||
|
||||
def __init__(self, config: DeformableDetrConfig):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -579,6 +579,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
|
||||
DISTILBERT_START_DOCSTRING,
|
||||
)
|
||||
class DistilBertForMaskedLM(DistilBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["vocab_projector.weight"]
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -1161,6 +1161,8 @@ class ElectraForPreTraining(ElectraPreTrainedModel):
|
||||
ELECTRA_START_DOCSTRING,
|
||||
)
|
||||
class ElectraForMaskedLM(ElectraPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1530,6 +1532,8 @@ class ElectraForMultipleChoice(ElectraPreTrainedModel):
|
||||
"""ELECTRA Model with a `language modeling` head on top for CLM fine-tuning.""", ELECTRA_START_DOCSTRING
|
||||
)
|
||||
class ElectraForCausalLM(ElectraPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["generator_lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -977,6 +977,8 @@ class ErnieModel(ErniePreTrainedModel):
|
||||
ERNIE_START_DOCSTRING,
|
||||
)
|
||||
class ErnieForPreTraining(ErniePreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForPreTraining.__init__ with Bert->Ernie,bert->ernie
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1087,7 +1089,7 @@ class ErnieForPreTraining(ErniePreTrainedModel):
|
||||
)
|
||||
class ErnieForCausalLM(ErniePreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->ErnieForCausalLM,Bert->Ernie,bert->ernie
|
||||
def __init__(self, config):
|
||||
@ -1228,7 +1230,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
|
||||
@add_start_docstrings("""Ernie Model with a `language modeling` head on top.""", ERNIE_START_DOCSTRING)
|
||||
class ErnieForMaskedLM(ErniePreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->Ernie,bert->ernie
|
||||
def __init__(self, config):
|
||||
|
@ -896,7 +896,7 @@ class EsmModel(EsmPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""ESM Model with a `language modeling` head on top.""", ESM_START_DOCSTRING)
|
||||
class EsmForMaskedLM(EsmPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", "lm_head.decoder.weight"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config):
|
||||
|
@ -657,6 +657,8 @@ class FlaubertModel(FlaubertPreTrainedModel):
|
||||
)
|
||||
# Copied transformers.models.xlm.modeling_xlm.XLMWithLMHeadModel with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
|
||||
class FlaubertWithLMHeadModel(FlaubertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = FlaubertModel(config)
|
||||
|
@ -1729,6 +1729,14 @@ class FlavaGlobalContrastiveHead(nn.Module):
|
||||
FLAVA_START_DOCSTRING.format(config="FlavaConfig") + FLAVA_PRETRAINING_START_DOCSTRING_EXTRA,
|
||||
)
|
||||
class FlavaForPreTraining(FlavaPreTrainedModel):
|
||||
# Those are linked to xxx.bias
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"mmm_text_head.decoder.bias",
|
||||
"mmm_image_head.decoder.bias",
|
||||
"mlm_head.decoder.bias",
|
||||
"mim_head.decoder.bias",
|
||||
]
|
||||
|
||||
def __init__(self, config: FlavaConfig, image_codebook: Optional[nn.Module] = None):
|
||||
super().__init__(config)
|
||||
self.flava = FlavaModel(config)
|
||||
|
@ -624,6 +624,8 @@ class FNetModel(FNetPreTrainedModel):
|
||||
FNET_START_DOCSTRING,
|
||||
)
|
||||
class FNetForPreTraining(FNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -716,6 +718,8 @@ class FNetForPreTraining(FNetPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""FNet Model with a `language modeling` head on top.""", FNET_START_DOCSTRING)
|
||||
class FNetForMaskedLM(FNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -992,6 +992,8 @@ def _get_shape(t):
|
||||
FSMT_START_DOCSTRING,
|
||||
)
|
||||
class FSMTModel(PretrainedFSMTModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.output_projection.weight"]
|
||||
|
||||
def __init__(self, config: FSMTConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1120,6 +1122,7 @@ class FSMTForConditionalGeneration(PretrainedFSMTModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"model.encoder.embed_positions.weight",
|
||||
"model.decoder.embed_positions.weight",
|
||||
"decoder.output_projection.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
"model.encoder.embed_positions.weight",
|
||||
|
@ -1193,6 +1193,8 @@ class FunnelForPreTraining(FunnelPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Funnel Transformer Model with a `language modeling` head on top.""", FUNNEL_START_DOCSTRING)
|
||||
class FunnelForMaskedLM(FunnelPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: FunnelConfig) -> None:
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -592,7 +592,7 @@ class GPTNeoXJapaneseModel(GPTNeoXJapanesePreTrainedModel):
|
||||
)
|
||||
class GPTNeoXJapaneseForCausalLM(GPTNeoXJapanesePreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "embed_out.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -856,7 +856,7 @@ class IBertModel(IBertPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""I-BERT Model with a `language modeling` head on top.""", IBERT_START_DOCSTRING)
|
||||
class IBertForMaskedLM(IBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder.bias", "lm_head.decoder.weight"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config):
|
||||
|
@ -849,6 +849,12 @@ class LayoutLMModel(LayoutLMPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""LayoutLM Model with a `language modeling` head on top.""", LAYOUTLM_START_DOCSTRING)
|
||||
class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"cls.predictions.decoder.bias",
|
||||
"cls.predictions.decoder.weight",
|
||||
"embeddings.position_ids",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -2212,6 +2212,8 @@ class LEDDecoder(LEDPreTrainedModel):
|
||||
LED_START_DOCSTRING,
|
||||
)
|
||||
class LEDModel(LEDPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: LEDConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -2341,6 +2343,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
||||
r"encoder.version",
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: LEDConfig):
|
||||
@ -2528,6 +2532,8 @@ class LEDForConditionalGeneration(LEDPreTrainedModel):
|
||||
LED_START_DOCSTRING,
|
||||
)
|
||||
class LEDForSequenceClassification(LEDPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: LEDConfig, **kwargs):
|
||||
warnings.warn(
|
||||
"The `transformers.LEDForSequenceClassification` class is deprecated and will be removed in version 5 of"
|
||||
@ -2662,6 +2668,8 @@ class LEDForSequenceClassification(LEDPreTrainedModel):
|
||||
LED_START_DOCSTRING,
|
||||
)
|
||||
class LEDForQuestionAnswering(LEDPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -1775,7 +1775,7 @@ class LongformerModel(LongformerPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Longformer Model with a `language modeling` head on top.""", LONGFORMER_START_DOCSTRING)
|
||||
class LongformerForMaskedLM(LongformerPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.decoder"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config):
|
||||
|
@@ -2137,9 +2137,7 @@ class LongT5ForConditionalGeneration(LongT5PreTrainedModel):
     LONGT5_START_DOCSTRING,
 )
 class LongT5EncoderModel(LongT5PreTrainedModel):
-    authorized_missing_keys = [
-        r"encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]

     def __init__(self, config: LongT5Config):
         super().__init__(config)

@ -1023,6 +1023,8 @@ class LxmertModel(LxmertPreTrainedModel):
|
||||
LXMERT_START_DOCSTRING,
|
||||
)
|
||||
class LxmertForPreTraining(LxmertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
# Configuration
|
||||
|
@ -1128,6 +1128,8 @@ class M2M100Decoder(M2M100PreTrainedModel):
|
||||
M2M_100_START_DOCSTRING,
|
||||
)
|
||||
class M2M100Model(M2M100PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: M2M100Config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1244,12 +1246,8 @@ class M2M100ForConditionalGeneration(M2M100PreTrainedModel):
|
||||
r"encoder.version",
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
r"model.encoder.embed_positions.weights",
|
||||
r"model.decoder.embed_positions.weights",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
r"model.encoder.embed_positions.weights",
|
||||
r"model.decoder.embed_positions.weights",
|
||||
r"encoder.embed_tokens.weight",
|
||||
r"decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: M2M100Config):
|
||||
|
@ -1087,6 +1087,8 @@ class MarianDecoder(MarianPreTrainedModel):
|
||||
"The bare Marian Model outputting raw hidden-states without any specific head on top.", MARIAN_START_DOCSTRING
|
||||
)
|
||||
class MarianModel(MarianPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: MarianConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1278,6 +1280,8 @@ class MarianMTModel(MarianPreTrainedModel):
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
r"embed_positions",
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight", "model.decoder.embed_positions.weight"]
|
||||
@ -1540,6 +1544,8 @@ class MarianDecoderWrapper(MarianPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->Marian, facebook/bart-base->Helsinki-NLP/opus-mt-fr-en
|
||||
class MarianForCausalLM(MarianPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -1150,6 +1150,8 @@ class MBartDecoder(MBartPreTrainedModel):
|
||||
MBART_START_DOCSTRING,
|
||||
)
|
||||
class MBartModel(MBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: MBartConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1273,6 +1275,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
||||
r"encoder.version",
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: MBartConfig):
|
||||
@ -1440,6 +1444,8 @@ class MBartForConditionalGeneration(MBartPreTrainedModel):
|
||||
MBART_START_DOCSTRING,
|
||||
)
|
||||
class MBartForSequenceClassification(MBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: MBartConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = MBartModel(config)
|
||||
@ -1568,6 +1574,8 @@ class MBartForSequenceClassification(MBartPreTrainedModel):
|
||||
MBART_START_DOCSTRING,
|
||||
)
|
||||
class MBartForQuestionAnswering(MBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1701,6 +1709,8 @@ class MBartDecoderWrapper(MBartPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->MBart, facebook/bart-base->facebook/mbart-large-cc25
|
||||
class MBartForCausalLM(MBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -1009,6 +1009,8 @@ class MegatronBertModel(MegatronBertPreTrainedModel):
|
||||
MEGATRON_BERT_START_DOCSTRING,
|
||||
)
|
||||
class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
|
||||
|
||||
def __init__(self, config, add_binary_head=True):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1115,7 +1117,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):
|
||||
class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"cls.predictions.decoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1261,7 +1263,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):
|
||||
class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler", r"seq_relationship"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -925,6 +925,12 @@ class MobileBertModel(MobileBertPreTrainedModel):
|
||||
MOBILEBERT_START_DOCSTRING,
|
||||
)
|
||||
class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"cls.predictions.decoder.weight",
|
||||
"cls.predictions.decoder.bias",
|
||||
"embeddings.position_ids",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.mobilebert = MobileBertModel(config)
|
||||
@ -1033,6 +1039,11 @@ class MobileBertForPreTraining(MobileBertPreTrainedModel):
|
||||
class MobileBertForMaskedLM(MobileBertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"cls.predictions.decoder.weight",
|
||||
"cls.predictions.decoder.bias",
|
||||
"embeddings.position_ids",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -574,7 +574,7 @@ class MPNetModel(MPNetPreTrainedModel):
|
||||
|
||||
|
||||
class MPNetForMaskedLM(MPNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"lm_head.decoder"]
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
|
||||
def __init__(self, config):
|
||||
|
@ -1292,6 +1292,7 @@ class MvpDecoder(MvpPreTrainedModel):
|
||||
)
|
||||
class MvpModel(MvpPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: MvpConfig):
|
||||
super().__init__(config)
|
||||
@ -1429,6 +1430,8 @@ class MvpModel(MvpPreTrainedModel):
|
||||
"The MVP Model with a language modeling head. Can be used for various text generation tasks.", MVP_START_DOCSTRING
|
||||
)
|
||||
class MvpForConditionalGeneration(MvpPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
|
||||
def __init__(self, config: MvpConfig):
|
||||
super().__init__(config)
|
||||
self.model = MvpModel(config)
|
||||
@ -1600,6 +1603,7 @@ class MvpForConditionalGeneration(MvpPreTrainedModel):
|
||||
)
|
||||
class MvpForSequenceClassification(MvpPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
|
||||
|
||||
def __init__(self, config: MvpConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
@ -1727,6 +1731,7 @@ class MvpForSequenceClassification(MvpPreTrainedModel):
|
||||
)
|
||||
class MvpForQuestionAnswering(MvpPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"final_logits_bias", r"lm_head.weight"]
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
@ -1856,6 +1861,8 @@ class MvpDecoderWrapper(MvpPreTrainedModel):
|
||||
|
||||
|
||||
class MvpForCausalLM(MvpPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -1038,6 +1038,8 @@ class NezhaModel(NezhaPreTrainedModel):
|
||||
NEZHA_START_DOCSTRING,
|
||||
)
|
||||
class NezhaForPreTraining(NezhaPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1140,7 +1142,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel):
|
||||
class NezhaForMaskedLM(NezhaPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", r"positions_encoding"]
|
||||
_keys_to_ignore_on_load_missing = [r"cls.predictions.decoder", r"positions_encoding"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -660,6 +660,8 @@ class NystromformerModel(NystromformerPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Nyströmformer Model with a `language modeling` head on top.""", NYSTROMFORMER_START_DOCSTRING)
|
||||
class NystromformerForMaskedLM(NystromformerPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -531,6 +531,8 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
|
||||
OPENAI_GPT_START_DOCSTRING,
|
||||
)
|
||||
class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = OpenAIGPTModel(config)
|
||||
@ -621,6 +623,8 @@ input sequence).
|
||||
OPENAI_GPT_START_DOCSTRING,
|
||||
)
|
||||
class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -1140,6 +1140,8 @@ class PegasusDecoder(PegasusPreTrainedModel):
|
||||
PEGASUS_START_DOCSTRING,
|
||||
)
|
||||
class PegasusModel(PegasusPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: PegasusConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1296,6 +1298,8 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel):
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
r"embed_positions.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: PegasusConfig):
|
||||
@ -1496,6 +1500,8 @@ class PegasusDecoderWrapper(PegasusPreTrainedModel):
|
||||
|
||||
|
||||
class PegasusForCausalLM(PegasusPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -1375,6 +1375,8 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
|
||||
PEGASUS_X_START_DOCSTRING,
|
||||
)
|
||||
class PegasusXModel(PegasusXPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: PegasusXConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1522,6 +1524,8 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel):
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
r"embed_positions.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: PegasusXConfig):
|
||||
|
@ -1125,6 +1125,8 @@ class PLBartDecoder(PLBartPreTrainedModel):
|
||||
PLBART_START_DOCSTRING,
|
||||
)
|
||||
class PLBartModel(PLBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.embed_tokens.weight", "encoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: PLBartConfig):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1247,6 +1249,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
|
||||
r"encoder.version",
|
||||
r"decoder.version",
|
||||
r"lm_head.weight",
|
||||
"decoder.embed_tokens.weight",
|
||||
"encoder.embed_tokens.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: PLBartConfig):
|
||||
@ -1411,6 +1415,8 @@ class PLBartForConditionalGeneration(PLBartPreTrainedModel):
|
||||
PLBART_START_DOCSTRING,
|
||||
)
|
||||
class PLBartForSequenceClassification(PLBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
|
||||
|
||||
def __init__(self, config: PLBartConfig, **kwargs):
|
||||
super().__init__(config, **kwargs)
|
||||
self.model = PLBartModel(config)
|
||||
@ -1548,6 +1554,8 @@ class PLBartDecoderWrapper(PLBartPreTrainedModel):
|
||||
|
||||
# Copied from transformers.models.bart.modeling_bart.BartForCausalLM with Bart->PLBart, facebook/bart-base->uclanlp/plbart-base
|
||||
class PLBartForCausalLM(PLBartPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -859,11 +859,7 @@ class ProphetNetNgramSelfAttention(nn.Module):
|
||||
):
|
||||
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
|
||||
|
||||
assert list(hidden_states.size()) == [
|
||||
batch_size,
|
||||
ngram_sequence_length,
|
||||
hidden_size,
|
||||
], (
|
||||
assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
|
||||
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
|
||||
f" {hidden_states.shape}"
|
||||
)
|
||||
@ -1774,6 +1770,8 @@ class ProphetNetDecoder(ProphetNetPreTrainedModel):
|
||||
PROPHETNET_START_DOCSTRING,
|
||||
)
|
||||
class ProphetNetModel(ProphetNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
|
||||
|
||||
def __init__(self, config: ProphetNetConfig):
|
||||
super().__init__(config)
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
@ -1901,6 +1899,12 @@ class ProphetNetModel(ProphetNetPreTrainedModel):
|
||||
PROPHETNET_START_DOCSTRING,
|
||||
)
|
||||
class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"decoder.word_embeddings.weight",
|
||||
"encoder.word_embeddings.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: ProphetNetConfig):
|
||||
super().__init__(config)
|
||||
self.prophetnet = ProphetNetModel(config)
|
||||
@ -2111,6 +2115,8 @@ class ProphetNetForConditionalGeneration(ProphetNetPreTrainedModel):
|
||||
PROPHETNET_START_DOCSTRING,
|
||||
)
|
||||
class ProphetNetForCausalLM(ProphetNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: ProphetNetConfig):
|
||||
# set config for CLM
|
||||
config = copy.deepcopy(config)
|
||||
|
@ -1140,6 +1140,8 @@ class RealmBertModel(RealmPreTrainedModel):
|
||||
REALM_START_DOCSTRING,
|
||||
)
|
||||
class RealmEmbedder(RealmPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1368,6 +1370,8 @@ class RealmScorer(RealmPreTrainedModel):
|
||||
REALM_START_DOCSTRING,
|
||||
)
|
||||
class RealmKnowledgeAugEncoder(RealmPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.realm = RealmBertModel(self.config)
|
||||
|
@ -2192,6 +2192,8 @@ class ReformerModel(ReformerPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Reformer Model with a `language modeling` head on top.""", REFORMER_START_DOCSTRING)
|
||||
class ReformerModelWithLMHead(ReformerPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
assert config.is_decoder, "If you want to use `ReformerModelWithLMHead` make sure that `is_decoder=True`."
|
||||
|
@ -1051,6 +1051,8 @@ class RoCBertModel(RoCBertPreTrainedModel):
|
||||
ROC_BERT_START_DOCSTRING,
|
||||
)
|
||||
class RoCBertForPreTraining(RoCBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1235,7 +1237,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
|
||||
@add_start_docstrings("""RoCBert Model with a `language modeling` head on top.""", ROC_BERT_START_DOCSTRING)
|
||||
class RoCBertForMaskedLM(RoCBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.__init__ with Bert->RoCBert,bert->roc_bert
|
||||
def __init__(self, config):
|
||||
@ -1361,7 +1363,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
|
||||
)
|
||||
class RoCBertForCausalLM(RoCBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_unexpected = [r"pooler"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
# Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.__init__ with BertLMHeadModel->RoCBertForCausalLM,Bert->RoCBert,bert->roc_bert
|
||||
def __init__(self, config):
|
||||
|
@ -954,6 +954,8 @@ class RoFormerModel(RoFormerPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""RoFormer Model with a `language modeling` head on top.""", ROFORMER_START_DOCSTRING)
|
||||
class RoFormerForMaskedLM(RoFormerPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1055,8 +1057,7 @@ class RoFormerForMaskedLM(RoFormerPreTrainedModel):
|
||||
"""RoFormer Model with a `language modeling` head on top for CLM fine-tuning.""", ROFORMER_START_DOCSTRING
|
||||
)
|
||||
class RoFormerForCausalLM(RoFormerPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@ -1256,6 +1256,7 @@ class Speech2TextForConditionalGeneration(Speech2TextPreTrainedModel):
|
||||
r"decoder.version",
|
||||
r"model.encoder.embed_positions.weights",
|
||||
r"model.decoder.embed_positions.weights",
|
||||
r"lm_head.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
r"model.encoder.embed_positions.weights",
|
||||
|
@ -745,6 +745,8 @@ class Speech2Text2DecoderWrapper(Speech2Text2PreTrainedModel):
|
||||
SPEECH_TO_TEXT_2_START_DOCSTRING,
|
||||
)
|
||||
class Speech2Text2ForCausalLM(Speech2Text2PreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -648,7 +648,11 @@ class SqueezeBertModel(SqueezeBertPreTrainedModel):
|
||||
@add_start_docstrings("""SqueezeBERT Model with a `language modeling` head on top.""", SQUEEZEBERT_START_DOCSTRING)
|
||||
class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
|
||||
|
||||
_keys_to_ignore_on_load_missing = [r"predictions.decoder.bias"]
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"predictions.decoder.bias",
|
||||
"cls.predictions.decoder.weight",
|
||||
"embeddings.position_ids",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
@@ -1758,9 +1758,7 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
     T5_START_DOCSTRING,
 )
 class T5EncoderModel(T5PreTrainedModel):
-    authorized_missing_keys = [
-        r"encoder.embed_tokens.weight",
-    ]
+    _keys_to_ignore_on_load_missing = [r"encoder.embed_tokens.weight"]

     def __init__(self, config: T5Config):
         super().__init__(config)

@ -1004,6 +1004,7 @@ class TapasModel(TapasPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""Tapas Model with a `language modeling` head on top.""", TAPAS_START_DOCSTRING)
|
||||
class TapasForMaskedLM(TapasPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
config_class = TapasConfig
|
||||
base_model_prefix = "tapas"
|
||||
|
||||
|
@ -1006,6 +1006,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
|
||||
TRANSFO_XL_START_DOCSTRING,
|
||||
)
|
||||
class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"crit\.out_projs\.\d+", r"crit\.out_layers\.\d+\.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = TransfoXLModel(config)
|
||||
|
@ -785,6 +785,8 @@ class TrOCRDecoderWrapper(TrOCRPreTrainedModel):
|
||||
TROCR_START_DOCSTRING,
|
||||
)
|
||||
class TrOCRForCausalLM(TrOCRPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["output_projection.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
config = copy.deepcopy(config)
|
||||
config.is_decoder = True
|
||||
|
@ -890,6 +890,8 @@ class ViltPooler(nn.Module):
|
||||
VILT_START_DOCSTRING,
|
||||
)
|
||||
class ViltForMaskedLM(ViltPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["mlm_score.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -871,6 +871,8 @@ class VisualBertModel(VisualBertPreTrainedModel):
|
||||
VISUAL_BERT_START_DOCSTRING,
|
||||
)
|
||||
class VisualBertForPreTraining(VisualBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
@ -1459,6 +1461,8 @@ class VisualBertRegionToPhraseAttention(nn.Module):
|
||||
VISUAL_BERT_START_DOCSTRING,
|
||||
)
|
||||
class VisualBertForRegionToPhraseAlignment(VisualBertPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["cls.predictions.decoder.bias"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@ -825,6 +825,7 @@ class XGLMForCausalLM(XGLMPreTrainedModel):
|
||||
base_model_prefix = "model"
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
r"model.embed_positions.weights",
|
||||
r"embed_positions.weights",
|
||||
r"lm_head.weight",
|
||||
]
|
||||
_keys_to_ignore_on_save = [
|
||||
|
@ -673,6 +673,8 @@ class XLMPredLayer(nn.Module):
|
||||
XLM_START_DOCSTRING,
|
||||
)
|
||||
class XLMWithLMHeadModel(XLMPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["pred_layer.proj.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.transformer = XLMModel(config)
|
||||
|
@ -876,11 +876,7 @@ class XLMProphetNetNgramSelfAttention(nn.Module):
|
||||
):
|
||||
batch_size, ngram_sequence_length, hidden_size = hidden_states.size()
|
||||
|
||||
assert list(hidden_states.size()) == [
|
||||
batch_size,
|
||||
ngram_sequence_length,
|
||||
hidden_size,
|
||||
], (
|
||||
assert list(hidden_states.size()) == [batch_size, ngram_sequence_length, hidden_size], (
|
||||
f"`hidden_states` should be of shape {batch_size, ngram_sequence_length, hidden_size}, but is of shape"
|
||||
f" {hidden_states.shape}"
|
||||
)
|
||||
@ -1798,6 +1794,8 @@ class XLMProphetNetDecoder(XLMProphetNetPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetModel with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
|
||||
class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["decoder.word_embeddings.weight", "encoder.word_embeddings.weight"]
|
||||
|
||||
def __init__(self, config: XLMProphetNetConfig):
|
||||
super().__init__(config)
|
||||
self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
|
||||
@ -1926,6 +1924,12 @@ class XLMProphetNetModel(XLMProphetNetPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForConditionalGeneration with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
|
||||
class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"decoder.word_embeddings.weight",
|
||||
"encoder.word_embeddings.weight",
|
||||
"lm_head.weight",
|
||||
]
|
||||
|
||||
def __init__(self, config: XLMProphetNetConfig):
|
||||
super().__init__(config)
|
||||
self.prophetnet = XLMProphetNetModel(config)
|
||||
@ -2139,6 +2143,8 @@ class XLMProphetNetForConditionalGeneration(XLMProphetNetPreTrainedModel):
|
||||
)
|
||||
# Copied from transformers.models.prophetnet.modeling_prophetnet.ProphetNetForCausalLM with microsoft/prophetnet-large-uncased->patrickvonplaten/xprophetnet-large-uncased-standalone, ProphetNet->XLMProphetNet, PROPHETNET->XLM_PROPHETNET
|
||||
class XLMProphetNetForCausalLM(XLMProphetNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = ["lm_head.weight"]
|
||||
|
||||
def __init__(self, config: XLMProphetNetConfig):
|
||||
# set config for CLM
|
||||
config = copy.deepcopy(config)
|
||||
|
@ -1296,6 +1296,8 @@ class XLNetModel(XLNetPreTrainedModel):
|
||||
XLNET_START_DOCSTRING,
|
||||
)
|
||||
class XLNetLMHeadModel(XLNetPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [r"lm_loss.weight"]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
self.attn_type = config.attn_type
|
||||
|
@ -852,6 +852,12 @@ class YosoModel(YosoPreTrainedModel):
|
||||
|
||||
@add_start_docstrings("""YOSO Model with a `language modeling` head on top.""", YOSO_START_DOCSTRING)
|
||||
class YosoForMaskedLM(YosoPreTrainedModel):
|
||||
_keys_to_ignore_on_load_missing = [
|
||||
"cls.predictions.decoder.bias",
|
||||
"cls.predictions.decoder.weight",
|
||||
"embeddings.position_ids",
|
||||
]
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
|
@@ -119,8 +119,6 @@ class AutoModelTest(unittest.TestCase):
             self.assertIsNotNone(model)
             self.assertIsInstance(model, BertForPreTraining)
             # Only one value should not be initialized and in the missing keys.
-            missing_keys = loading_info.pop("missing_keys")
-            self.assertListEqual(["cls.predictions.decoder.bias"], missing_keys)
             for key, value in loading_info.items():
                 self.assertEqual(len(value), 0)

@@ -424,7 +424,6 @@ class BartModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
     is_encoder_decoder = True
     fx_compatible = True
     test_pruning = False
-    test_missing_keys = False

     def setUp(self):
         self.model_tester = BartModelTester(self)
@@ -1445,6 +1444,7 @@ class BartStandaloneDecoderModelTest(ModelTesterMixin, GenerationTesterMixin, un
     fx_comptatible = True
     test_pruning = False
     is_encoder_decoder = False
+    test_missing_keys = False

     def setUp(
         self,

@@ -1468,11 +1468,24 @@ class ModelTesterMixin:
             base_model_prefix = model.base_model_prefix

             if hasattr(model, base_model_prefix):
+                extra_params = {k: v for k, v in model.named_parameters() if not k.startswith(base_model_prefix)}
+                extra_params.update({k: v for k, v in model.named_buffers() if not k.startswith(base_model_prefix)})
+                # Some models define this as None
+                if model._keys_to_ignore_on_load_missing:
+                    for key in model._keys_to_ignore_on_load_missing:
+                        extra_params.pop(key, None)
+
+                if not extra_params:
+                    # In that case, we *are* on a head model, but every
+                    # single key is not actual parameters and this is
+                    # tested in `test_tied_model_weights_key_ignore` test.
+                    continue
+
                 with tempfile.TemporaryDirectory() as temp_dir_name:
                     model.base_model.save_pretrained(temp_dir_name)
                     model, loading_info = model_class.from_pretrained(temp_dir_name, output_loading_info=True)
                     with self.subTest(msg=f"Missing keys for {model.__class__.__name__}"):
-                        self.assertGreater(len(loading_info["missing_keys"]), 0)
+                        self.assertGreater(len(loading_info["missing_keys"]), 0, model.__class__.__name__)

     def test_tie_model_weights(self):
         if not self.test_torchscript:
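In words: `test_correct_missing_keys` saves only the base model, reloads it into the head class, and expects the head-specific weights to show up in `missing_keys`. The new bookkeeping computes which head-side parameters and buffers are genuinely extra; when that set is empty (every head key is tied or listed in `_keys_to_ignore_on_load_missing`), nothing can be reported missing and the `assertGreater` would fail for the wrong reason, hence the `continue`. A condensed, illustrative restatement of that decision outside the test harness (helper name is made up):

def head_has_loadable_extra_weights(model) -> bool:
    """Rough, standalone sketch of the escape-hatch logic above (for illustration only)."""
    prefix = model.base_model_prefix
    extra = {k: v for k, v in model.named_parameters() if not k.startswith(prefix)}
    extra.update({k: v for k, v in model.named_buffers() if not k.startswith(prefix)})
    for key in model._keys_to_ignore_on_load_missing or []:
        extra.pop(key, None)
    return bool(extra)

# e.g. head_has_loadable_extra_weights(BertForMaskedLM(BertConfig())) -> True or False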
@@ -1522,6 +1535,54 @@ class ModelTesterMixin:
         # self.assertTrue(model.transformer.wte.weight.shape, model.lm_head.weight.shape)
         # self.assertTrue(check_same_values(model.transformer.wte, model.lm_head))

+    def test_tied_model_weights_key_ignore(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+        for model_class in self.all_model_classes:
+            model_tied = model_class(config)
+            with tempfile.TemporaryDirectory() as d:
+                model_tied.save_pretrained(d)
+
+                # We are nuking ALL weights on file, so every parameter should
+                # yell on load. We're going to detect if we yell too much, or too little.
+                with open(os.path.join(d, "pytorch_model.bin"), "wb") as f:
+                    torch.save({}, f)
+                model_reloaded, infos = model_class.from_pretrained(d, output_loading_info=True)
+
+                # ! Actually we could use `state_dict()` and check iteratively the tensors which are the same (for instance using `tensor.data_ptr()`). to detect the duplicates.
+                # ```python
+                # model = GPT2LMHeadModel.from_pretrained("gpt2")
+                # "lm_head.weight" in model.state_dict().keys()  # True
+                # "lm_head.weight" in model.named_parameters()  # False
+                # In [6]: model.lm_head.weight.data_ptr()
+                # Out[6]: 139901378371648
+                # In [9]: model.transformer.wte.weight.data_ptr()
+                # Out[9]: 139901378371648  # Same PTR, it's the same DATA ! we would need to check for stride too to be 100% accurate.
+                # ```
+
+                prefix = f"{model_reloaded.base_model_prefix}."
+                params = dict(model_reloaded.named_parameters())
+                params.update(dict(model_reloaded.named_buffers()))
+                # param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
+                param_names = set(k[len(prefix) :] if k.startswith(prefix) else k for k in params.keys())
+
+                missing_keys = set(infos["missing_keys"])
+
+                extra_missing = missing_keys - param_names
+                # missed_missing = param_names - missing_keys
+
+                self.assertEqual(
+                    extra_missing,
+                    set(),
+                    f"This model {model_class.__name__} might be missing some `keys_to_ignore`: {extra_missing}",
+                )
+
+                # self.assertEqual(
+                #     missed_missing,
+                #     set(),
+                #     f"This model {model_class.__name__} ignores keys {missed_missing} but they look like real"
+                #     " parameters",
+                # )
+
     def test_model_outputs_equivalence(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

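The commented block above is the crux of the new test: for tied weights, `state_dict()` still contains the dependent key while `named_parameters()` (which de-duplicates shared parameters) does not, so subtracting real parameter and buffer names from `missing_keys` isolates exactly the keys that must be covered by `_keys_to_ignore_on_load_missing`. The discrepancy is easy to reproduce with the same model the comment uses (the pointer values naturally differ per run):

from transformers import GPT2LMHeadModel

model = GPT2LMHeadModel.from_pretrained("gpt2")

print("lm_head.weight" in model.state_dict())              # True: the tied key is serialized
print("lm_head.weight" in dict(model.named_parameters()))  # False: de-duplicated by PyTorch

# Both names resolve to the same underlying storage; that is what "tied" means here.
print(model.lm_head.weight.data_ptr() == model.transformer.wte.weight.data_ptr())  # True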