Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-03 21:00:08 +06:00
fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model (#36572)
* fix "Cannot copy out of meta tensor; no data!" issue for BartForConditionalGeneration model * follow Marc's suggestion to use _tie_weights to fix Signed-off-by: Yao, Matrix <matrix.yao@intel.com> * fix review comments. Signed-off-by: N <matrix.yao@intel.com> * fix quality Signed-off-by: N <matrix.yao@intel.com> --------- Signed-off-by: Yao, Matrix <matrix.yao@intel.com> Signed-off-by: N <matrix.yao@intel.com>
This commit is contained in: parent 706703bba6 · commit 4fa91b1be5
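
For context, the failure this commit addresses shows up when a checkpoint is loaded through the meta-device path. A minimal reproduction sketch follows; the checkpoint name comes from the code comment in the diff below, while the `device_map="auto"` flag (which requires accelerate) is an assumption about how the meta-device path gets exercised, not part of the commit:

    # Reproduction sketch (assumptions: torch and accelerate installed; the
    # device_map load path is what routes weights through the meta device).
    from transformers import BartForConditionalGeneration

    # Before this commit, tying weights while `model.shared` still lived on the
    # meta device raised: "Cannot copy out of meta tensor; no data!"
    model = BartForConditionalGeneration.from_pretrained(
        "facebook/bart-large-cnn",
        device_map="auto",  # assumption: any meta-device load path triggers the bug
    )
    print(model.model.shared.weight.device)  # should be a real device, not "meta"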
src/transformers/models/bart/modeling_bart.py
@@ -1442,8 +1442,15 @@ class BartModel(BartPreTrainedModel):
     def _tie_weights(self):
         if self.config.tie_word_embeddings:
-            self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
-            self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
+            # Some checkpoints, e.g. "facebook/bart-large-cnn", store the embedding weight in
+            # decoder.embed_tokens rather than shared, so check for that here; see issue #36247
+            if self.shared.weight.device == torch.device(
+                "meta"
+            ) and self.decoder.embed_tokens.weight.device != torch.device("meta"):
+                self._tie_or_clone_weights(self.encoder.embed_tokens, self.decoder.embed_tokens)
+                self._tie_or_clone_weights(self.shared, self.decoder.embed_tokens)
+            else:
+                self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
+                self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)

     def get_input_embeddings(self):
         return self.shared
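
For readers unfamiliar with the helper: `_tie_or_clone_weights(a, b)` on `PreTrainedModel` makes the first module's weight refer to the second's, which is why the non-meta `decoder.embed_tokens` becomes the source of truth when `shared` is still empty. A simplified sketch of the idea (it omits the TorchScript clone branch and the bias/padding handling the real method performs):

    import torch.nn as nn

    def tie_or_clone_weights(dst: nn.Module, src: nn.Module) -> None:
        # Rebind dst's weight to src's parameter so both share one storage;
        # after this, updating one updates the other.
        dst.weight = src.weight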
src/transformers/models/bart/modeling_bart.py
@@ -1599,6 +1606,11 @@ class BartForConditionalGeneration(BartPreTrainedModel, GenerationMixin):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self.model._tie_weights()
+            self._tie_or_clone_weights(self.lm_head, self.model.shared)
+
     @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BART_GENERATION_EXAMPLE)
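
After loading, all four embedding views should share a single storage; a quick sanity-check sketch (comparing `data_ptr()` is just one way to confirm shared storage):

    from transformers import BartForConditionalGeneration

    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large-cnn")
    ptrs = {
        model.model.shared.weight.data_ptr(),
        model.model.encoder.embed_tokens.weight.data_ptr(),
        model.model.decoder.embed_tokens.weight.data_ptr(),
        model.lm_head.weight.data_ptr(),
    }
    assert len(ptrs) == 1  # all tied to the same underlying tensor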
src/transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py
@@ -2479,6 +2479,11 @@ class BigBirdPegasusForConditionalGeneration(BigBirdPegasusPreTrainedModel, GenerationMixin):
     def set_output_embeddings(self, new_embeddings):
         self.lm_head = new_embeddings

+    def _tie_weights(self):
+        if self.config.tie_word_embeddings:
+            self.model._tie_weights()
+            self._tie_or_clone_weights(self.lm_head, self.model.shared)
+
     @add_start_docstrings_to_model_forward(BIGBIRD_PEGASUS_INPUTS_DOCSTRING)
     @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     @add_end_docstrings(BIGBIRD_PEGASUS_GENERATION_EXAMPLE)
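
The BigBirdPegasus hunk applies the same fix to that model family. An end-to-end check in the same spirit (the checkpoint name and device_map usage are illustrative assumptions, not part of the commit):

    import torch
    from transformers import BigBirdPegasusForConditionalGeneration

    # Before this commit, a meta-device load could raise
    # "Cannot copy out of meta tensor; no data!" during weight tying.
    model = BigBirdPegasusForConditionalGeneration.from_pretrained(
        "google/bigbird-pegasus-large-arxiv",  # assumption: illustrative checkpoint
        device_map="auto",                     # assumption: requires accelerate
    )
    assert model.lm_head.weight.device != torch.device("meta")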