fix speecht5 failure issue in test_peft_gradient_checkpointing_enable_disable (#34454)

* fix speecht5 failure issue in test_peft_gradient_checkpointing_enable_disable

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>

* [run-slow] speecht5

---------

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Matt <rocketknight1@gmail.com>
This commit is contained in:
Wang, Yi 2024-12-03 21:58:54 +08:00 committed by GitHub
parent 7a7f27697a
commit 125de41643
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 1 addition and 13 deletions

View File

@@ -2114,7 +2114,7 @@ class SpeechT5Model(SpeechT5PreTrainedModel):
return self.encoder.get_input_embeddings()
if isinstance(self.decoder, SpeechT5DecoderWithTextPrenet):
return self.decoder.get_input_embeddings()
return None
raise NotImplementedError
def set_input_embeddings(self, value):
if isinstance(self.encoder, SpeechT5EncoderWithTextPrenet):

View File

@@ -237,12 +237,6 @@ class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
def test_torchscript_simple(self):
pass
@unittest.skip(
reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
)
def test_peft_gradient_checkpointing_enable_disable(self):
pass
@require_torch
class SpeechT5ForSpeechToTextTester:
@@ -1741,12 +1735,6 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
def test_training_gradient_checkpointing_use_reentrant_false(self):
pass
@unittest.skip(
reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
)
def test_peft_gradient_checkpointing_enable_disable(self):
pass
# overwrite from test_modeling_common
def _mock_init_weights(self, module):
if hasattr(module, "weight") and module.weight is not None: