diff --git a/src/transformers/models/bart/modeling_bart.py b/src/transformers/models/bart/modeling_bart.py
index bec24c4f78f..ac9945df71b 100755
--- a/src/transformers/models/bart/modeling_bart.py
+++ b/src/transformers/models/bart/modeling_bart.py
@@ -1442,8 +1442,15 @@ class BartModel(BartPreTrainedModel):
 
     def _tie_weights(self):
         if self.config.tie_word_embeddings:
-            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
-            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+            # Some checkpoints, e.g. "facebook/bart-large-cnn", store the embedding weight in decoder.embed_tokens, so handle that case here; see issue #36247
+            if self.shared.weight.device == torch.device(
+                "meta"
+            ) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
+                self._tie_or_clone_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
+                self._tie_or_clone_weights(self.shared, self.decoder.embed_tokens)
+            else:
+                self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+                self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
 
     def get_input_embeddings(self):
         return self.shared
@@ -1599,6 +1606,11 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self.model._tie_weights()
+            self._tie_or_clone_weights(self.lm_head, self.model.shared)
+
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BART_GENERATION_EXAMPLE)
diff --git a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
index 61634901289..935e5ee9fd6 100755
--- a/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
+++ b/src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2479,6 +2479,11 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, Gene
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings
 
+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self.model._tie_weights()
+            self._tie_or_clone_weights(self.lm_head, self.model.shared)
+
     @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
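
The patch handles the case where the model skeleton is initialized on the meta device and only `decoder.embed_tokens` receives real tensors from the checkpoint, so `shared` must be tied to the decoder embedding rather than the other way around. Below is a minimal sketch of how the fixed path can be exercised and verified; it is not part of the patch. It assumes a transformers build that includes this change, and uses `low_cpu_mem_usage=True` as one way to trigger meta-device initialization during loading.

```python
# Sketch: verify that embeddings end up tied after loading a checkpoint
# whose embedding weight lives under decoder.embed_tokens (issue #36247).
import torch
from transformers import AutoModelForSeq2SeqLM

# low_cpu_mem_usage=True first builds the model on the meta device, then
# materializes tensors from the checkpoint -- the path where model.shared
# previously stayed on "meta" and tying produced empty weights.
model = AutoModelForSeq2SeqLM.from_pretrained(
    "facebook/bart-large-cnn", low_cpu_mem_usage=True
)

# After _tie_weights, all embedding tables should share one storage and
# none of them should remain on the meta device.
shared = model.model.shared.weight
assert shared.device != torch.device("meta")
assert shared.data_ptr() == model.model.encoder.embed_tokens.weight.data_ptr()
assert shared.data_ptr() == model.model.decoder.embed_tokens.weight.data_ptr()
assert shared.data_ptr() == model.lm_head.weight.data_ptr()
```

The same checks apply to `BigBirdPegasusForConditionalGeneration`, since its new `_tie_weights` override delegates to `self.model._tie_weights()` and then ties `lm_head` to `model.shared` in the same way.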