idefics2 enable_input_require_grads not aligned with disable_input_require_grads (#33194)
* idefics2: enable_input_require_grads is not aligned with disable_input_require_grads, which makes disabling fail for peft + idefics2 checkpoints
* split test case
* fix ci failure
* refine test

Signed-off-by: Wang, Yi <yi.a.wang@intel.com>
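For context, the base `PreTrainedModel` implements this pair of methods around a single hook, so a subclass that overrides only `enable_input_require_grads` leaves the inherited `disable_input_require_grads` looking for a hook attribute that was never set. A minimal sketch of the mismatch (base-class side simplified from memory, not verbatim):

    # Base class (simplified): one hook, stored as self._require_grads_hook
    def enable_input_require_grads(self):
        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

    def disable_input_require_grads(self):
        self._require_grads_hook.remove()

Idefics2 overrides `enable_input_require_grads` to register two hooks (`_text_require_grads_hook` and `_vision_require_grads_hook`), so the inherited disable raised AttributeError on the never-set `_require_grads_hook`. The diff below adds matching overrides that remove both hooks.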
commit 74026b473e (parent: 642256de71)
@@ -1256,6 +1256,10 @@ class Idefics2Model(Idefics2PreTrainedModel):
             make_inputs_require_grads
         )
 
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
     def get_input_embeddings(self):
         return self.text_model.get_input_embeddings()
 
@@ -1466,6 +1470,10 @@ class Idefics2ForConditionalGeneration(Idefics2PreTrainedModel):
             make_inputs_require_grads
         )
 
+    def disable_input_require_grads(self):
+        self._text_require_grads_hook.remove()
+        self._vision_require_grads_hook.remove()
+
     def get_input_embeddings(self):
         return self.model.text_model.get_input_embeddings()
 
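As a usage sketch of the scenario the commit message describes (the checkpoint id is just an example; `_hf_peft_config_loaded` is flipped by hand here the same way the new test below does it, where normally loading a PEFT adapter sets it):

    from transformers import Idefics2ForConditionalGeneration

    model = Idefics2ForConditionalGeneration.from_pretrained("HuggingFaceM4/idefics2-8b")
    model._hf_peft_config_loaded = True  # normally set when a PEFT adapter is loaded

    model.gradient_checkpointing_enable()   # registers the text + vision hooks
    model.gradient_checkpointing_disable()  # previously raised AttributeError; now removes both hooks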
@@ -239,6 +239,12 @@ class SpeechT5ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
     def test_torchscript_simple(self):
         pass
 
+    @unittest.skip(
+        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
+    )
+    def test_peft_gradient_checkpointing_enable_disable(self):
+        pass
+
 
 @require_torch
 class SpeechT5ForSpeechToTextTester:
@@ -1743,6 +1749,12 @@ class SpeechT5ForSpeechToSpeechTest(ModelTesterMixin, unittest.TestCase):
     def test_training_gradient_checkpointing_use_reentrant_false(self):
         pass
 
+    @unittest.skip(
+        reason="Model returns None for input_embeds, check: https://github.com/huggingface/transformers/issues/33527"
+    )
+    def test_peft_gradient_checkpointing_enable_disable(self):
+        pass
+
     # overwrite from test_modeling_common
     def _mock_init_weights(self, module):
         if hasattr(module, "weight") and module.weight is not None:
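The skipped test drives the PEFT path of `gradient_checkpointing_enable()`, which calls `enable_input_require_grads()`, which in turn registers a forward hook on the result of `get_input_embeddings()`. A sketch of why that cannot work for SpeechT5 (paraphrasing the base-class behavior and the skip reason, not verbatim code):

    # Inside the base-class enable_input_require_grads(), roughly:
    embeddings = model.get_input_embeddings()   # SpeechT5 returns None here, see issue 33527
    embeddings.register_forward_hook(hook)      # fails: None has no register_forward_hook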
@@ -403,6 +403,44 @@ class ModelTesterMixin:
                     m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
                 )
 
+    def test_peft_gradient_checkpointing_enable_disable(self):
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+
+        for model_class in self.all_model_classes:
+            if not model_class.supports_gradient_checkpointing:
+                continue
+
+            # at init model should have gradient checkpointing disabled
+            model = model_class(config)
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # check enable works
+            model._hf_peft_config_loaded = True
+            try:
+                model.gradient_checkpointing_enable()
+            except NotImplementedError:
+                continue
+
+            self.assertTrue(model.is_gradient_checkpointing)
+
+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to True
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertTrue(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to True"
+                    )
+
+            # check disable works
+            model.gradient_checkpointing_disable()
+            self.assertFalse(model.is_gradient_checkpointing)
+
+            # Loop over all modules and check that relevant modules have gradient_checkpointing set to False
+            for n, m in model.named_modules():
+                if hasattr(m, "gradient_checkpointing"):
+                    self.assertFalse(
+                        m.gradient_checkpointing, f"Module {n} does not have gradient_checkpointing set to False"
+                    )
+
     @is_flaky(description="low likelihood of failure, reason not yet discovered")
     def test_save_load_fast_init_from_base(self):
         config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
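For reference, the behavior this mixin test relies on lives in `PreTrainedModel`; roughly (a simplified sketch, not the verbatim implementation):

    # Sketch of the PreTrainedModel code path the test exercises
    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        ...  # enable checkpointing on submodules that support it
        if getattr(self, "_hf_peft_config_loaded", False):
            # With PEFT, base weights are frozen, so inputs must require grad
            # for checkpointed segments to have a path to backpropagate through.
            self.enable_input_require_grads()

    def gradient_checkpointing_disable(self):
        ...  # disable checkpointing on submodules
        if getattr(self, "_hf_peft_config_loaded", False):
            self.disable_input_require_grads()

This is why setting `model._hf_peft_config_loaded = True` in the test is enough to exercise both `enable_input_require_grads` and `disable_input_require_grads` on every model, and why an enable override without a matching disable override only failed on the disable path.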