Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 10:12:23 +06:00
Add tie_weights() to LM heads and set bias in set_output_embeddings() (#28948)
* Add tie_weights() to LM heads and set bias in set_output_embeddings(). The bias was not tied correctly in some LM heads, and this change should fix that.
* Move test_save_and_load_low_cpu_mem_usage to ModelTesterMixin.
* Add _tie_weights() to MPNet and Vilt.
* Skip the low cpu mem usage test for Deta/DeformableDetr since they cannot be initialized on the meta device.
* Rename the test to save_load to match the convention.
This commit is contained in:
parent
3f4e79d29c
commit
725f4ad1cc
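For orientation, here is a minimal sketch of the pattern this commit applies across the affected models (the class and attribute names below are illustrative, not taken from the diff): the LM head owns a `bias` parameter and assigns it to `decoder.bias` so that `resize_token_embeddings` resizes both together, exposes `_tie_weights()` so the link can be re-established after the weights are materialized (for example when loading with `low_cpu_mem_usage=True`, where parameters are first created on the meta device), and `set_output_embeddings()` also re-points the head's bias at the new decoder's bias.

import torch
from torch import nn


class ToyLMPredictionHead(nn.Module):
    # Illustrative stand-in for the *LMPredictionHead classes touched by this commit.
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
        self.bias = nn.Parameter(torch.zeros(vocab_size))
        # Link the two so the bias is resized together with the decoder weight
        # by `resize_token_embeddings`.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-establish the link after the weights have been (re)materialized;
        # without this, the head bias can be left untied (e.g. on the meta device).
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(hidden_states)


class ToyModelWithLMHead(nn.Module):
    # Illustrative counterpart of the set_output_embeddings() change in the diff.
    def __init__(self, hidden_size: int, vocab_size: int):
        super().__init__()
        self.lm_head = ToyLMPredictionHead(hidden_size, vocab_size)

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings
        # Mirrors the line the commit adds so the head bias follows the new decoder.
        self.lm_head.bias = new_embeddings.bias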
@@ -692,6 +692,9 @@ class BertLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1062,6 +1065,7 @@ class BertForPreTraining(BertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1171,6 +1175,7 @@ class BertLMHeadModel(BertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

@@ -1324,6 +1329,7 @@ class BertForMaskedLM(BertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -1707,6 +1707,9 @@ class BigBirdLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -2266,6 +2269,7 @@ class BigBirdForPreTraining(BigBirdPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=BigBirdForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -2378,6 +2382,7 @@ class BigBirdForMaskedLM(BigBirdPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)

@@ -2519,6 +2524,7 @@ class BigBirdForCausalLM(BigBirdPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(BIG_BIRD_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -523,6 +523,9 @@ class BlipTextLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -816,6 +819,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     def forward(
         self,
@@ -608,6 +608,9 @@ class ErnieLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -995,6 +998,7 @@ class ErnieForPreTraining(ErniePreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=ErnieForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1109,6 +1113,7 @@ class ErnieForCausalLM(ErniePreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

@@ -1269,6 +1274,7 @@ class ErnieForMaskedLM(ErniePreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ERNIE_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -589,6 +589,9 @@ class LayoutLMLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -869,6 +872,7 @@ class LayoutLMForMaskedLM(LayoutLMPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(LAYOUTLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -318,6 +318,9 @@ class MarkupLMLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)
@@ -659,6 +659,9 @@ class MegatronBertLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1023,6 +1026,7 @@ class MegatronBertForPreTraining(MegatronBertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MegatronBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1132,6 +1136,7 @@ class MegatronBertForCausalLM(MegatronBertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)

@@ -1290,6 +1295,7 @@ class MegatronBertForMaskedLM(MegatronBertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MEGATRON_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -587,6 +587,7 @@ class MPNetForMaskedLM(MPNetPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.lm_head.decoder = new_embeddings
+        self.lm_head.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MPNET_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(

@@ -659,6 +660,9 @@ class MPNetLMHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, features, **kwargs):
         x = self.dense(features)
         x = gelu(x)
@@ -820,6 +820,9 @@ class MraLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1053,6 +1056,7 @@ class MraForMaskedLM(MraPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(MRA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -679,6 +679,9 @@ class NezhaLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1044,6 +1047,7 @@ class NezhaForPreTraining(NezhaPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=NezhaForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)

@@ -1152,6 +1156,7 @@ class NezhaForMaskedLM(NezhaPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(NEZHA_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -428,6 +428,9 @@ class NystromformerLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -666,6 +669,7 @@ class NystromformerForMaskedLM(NystromformerPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(NYSTROMFORMER_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -683,6 +683,9 @@ class QDQBertLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1024,6 +1027,7 @@ class QDQBertLMHeadModel(QDQBertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)

@@ -1190,6 +1194,7 @@ class QDQBertForMaskedLM(QDQBertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(QDQBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -744,6 +744,9 @@ class RoCBertLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1090,6 +1093,7 @@ class RoCBertForPreTraining(RoCBertPreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertForPreTraining.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)

@@ -1282,6 +1286,7 @@ class RoCBertForMaskedLM(RoCBertPreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertForMaskedLM.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     def forward(

@@ -1419,6 +1424,7 @@ class RoCBertForCausalLM(RoCBertPreTrainedModel):
     # Copied from transformers.models.bert.modeling_bert.BertLMHeadModel.set_output_embeddings
     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(ROC_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=CausalLMOutputWithCrossAttentions, config_class=_CONFIG_FOR_DOC)
@@ -729,6 +729,9 @@ class TapasLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -1008,6 +1011,7 @@ class TapasForMaskedLM(TapasPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(TAPAS_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)
@@ -896,6 +896,7 @@ class ViltForMaskedLM(ViltPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.mlm_score.decoder = new_embeddings
+        self.mlm_score.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(VILT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=MaskedLMOutput, config_class=_CONFIG_FOR_DOC)

@@ -1042,6 +1043,9 @@ class ViltMLMHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, x):
         x = self.transform(x)
         x = self.decoder(x)
@@ -499,6 +499,9 @@ class VisualBertLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -879,6 +882,7 @@ class VisualBertForPreTraining(VisualBertPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(VISUAL_BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @replace_return_docstrings(output_type=VisualBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC)
@@ -619,6 +619,9 @@ class YosoLMPredictionHead(nn.Module):
         # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
         self.decoder.bias = self.bias

+    def _tie_weights(self):
+        self.decoder.bias = self.bias
+
     def forward(self, hidden_states):
         hidden_states = self.transform(hidden_states)
         hidden_states = self.decoder(hidden_states)

@@ -857,6 +860,7 @@ class YosoForMaskedLM(YosoPreTrainedModel):

     def set_output_embeddings(self, new_embeddings):
         self.cls.predictions.decoder = new_embeddings
+        self.cls.predictions.bias = new_embeddings.bias

     @add_start_docstrings_to_model_forward(YOSO_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
     @add_code_sample_docstrings(
@@ -564,6 +564,10 @@ class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineT
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )

+    @unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization")
+    def test_save_load_low_cpu_mem_usage(self):
+        pass
+
     def test_two_stage_training(self):
         model_class = DeformableDetrForObjectDetection
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -520,6 +520,10 @@ class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
                         msg=f"Parameter {name} of model {model_class} seems not properly initialized",
                     )

+    @unittest.skip("Cannot be initialized on meta device as some weights are modified during the initialization")
+    def test_save_load_low_cpu_mem_usage(self):
+        pass
+

 TOLERANCE = 1e-4

@@ -435,6 +435,23 @@ class ModelTesterMixin:
                 max_diff = (model_slow_init.state_dict()[key] - model_fast_init.state_dict()[key]).sum().item()
                 self.assertLessEqual(max_diff, 1e-3, msg=f"{key} not identical")

+    def test_save_load_low_cpu_mem_usage(self):
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            for model_class in self.all_model_classes:
+                config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+                model_to_save = model_class(config)
+
+                model_to_save.save_pretrained(tmpdirname)
+
+                model = model_class.from_pretrained(
+                    tmpdirname,
+                    low_cpu_mem_usage=True,
+                )
+
+                # The low_cpu_mem_usage=True causes the model params to be initialized with device=meta. If there are
+                # any unloaded or untied parameters, then trying to move it to device=torch_device will throw an error.
+                model.to(torch_device)
+
     def test_fast_init_context_manager(self):
         # 1. Create a dummy class. Should have buffers as well? To make sure we test __init__
         class MyClass(PreTrainedModel):
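As a rough usage sketch of what the new ModelTesterMixin test exercises from the user side (the checkpoint name below is only an example): with `low_cpu_mem_usage=True` the model's parameters are first created on the meta device, and any parameter that is neither loaded from the checkpoint nor re-tied via `_tie_weights()` stays there, so the subsequent `.to(...)` call fails; with the head bias correctly tied, the move succeeds.

import torch
from transformers import AutoModelForMaskedLM

# Any BERT-style masked-LM checkpoint works here; this one is just an example.
model = AutoModelForMaskedLM.from_pretrained(
    "google-bert/bert-base-uncased",
    low_cpu_mem_usage=True,  # weights start out on the meta device
)

# Before this fix, an untied LM-head bias could remain on the meta device and
# make this move raise; after the fix it goes through.
model.to("cuda" if torch.cuda.is_available() else "cpu")

# The decoder bias and the head's own bias should be the same tensor object.
assert model.cls.predictions.decoder.bias is model.cls.predictions.bias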