fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model (#36572)

* fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model

* follow Marc's suggestion to use _tie_weights to fix

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>

* fix review comments.

Signed-off-by: N <matrix.yao@intel.com>

* fix quality

Signed-off-by: N <matrix.yao@intel.com>

---------

Signed-off-by: Yao, Matrix <matrix.yao@intel.com>
Signed-off-by: N <matrix.yao@intel.com>
This commit is contained in:
Yao Matrix 2025-03-19 17:48:47 +08:00 committed by GitHub
parent 706703bba6
commit 4fa91b1be5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 19 additions and 2 deletions

View File

@@ -1442,8 +1442,15 @@ class BartModel(BartPreTrainedModel):
def _tie_weights(self):
    """Tie encoder/decoder input embeddings to the shared embedding table.

    Some checkpoints (e.g. "facebook/bart-large-cnn") materialize the
    embedding weight only under ``decoder.embed_tokens`` while ``shared``
    is still on the "meta" device; tying onto ``shared`` in that state
    raises "Cannot copy out of meta tensor; no data!". In that case tie
    everything onto the decoder's table instead (see issue #36247).
    """
    if self.config.tie_word_embeddings:
        # NOTE: the unconditional pre-fix tie calls that preceded this
        # check were removed — running them first would copy out of the
        # meta tensor and reintroduce the bug this branch exists to fix.
        if self.shared.weight.device == torch.device(
            "meta"
        ) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
            self._tie_or_clone_weights(self.shared, self.decoder.embed_tokens)
        else:
            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
def get_input_embeddings(self):
    """Return the embedding table shared between encoder and decoder."""
    embeddings = self.shared
    return embeddings
@@ -1599,6 +1606,11 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Install *new_embeddings* as the model's LM head."""
    self.lm_head = new_embeddings
def _tie_weights(self):
    """Tie the LM head to the shared embedding table when the config enables word-embedding tying."""
    if not self.config.tie_word_embeddings:
        return
    # Tie the encoder/decoder embeddings first, then point lm_head at shared.
    self.model._tie_weights()
    self._tie_or_clone_weights(self.lm_head, self.model.shared)
@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(BART_GENERATION_EXAMPLE)

View File

@@ -2479,6 +2479,11 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
def set_output_embeddings(self, new_embeddings):
    """Replace the output projection (LM head) with the given module."""
    self.lm_head = new_embeddings
def _tie_weights(self):
    """When weight tying is configured, tie base-model embeddings and then clone/share them into lm_head."""
    if not self.config.tie_word_embeddings:
        return
    self.model._tie_weights()
    # lm_head shares (or clones, under torchscript) the base model's embedding matrix.
    self._tie_or_clone_weights(self.lm_head, self.model.shared)
@add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
@replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
@add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)