diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 9279950d3ac..41ea192cbb9 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -1750,6 +1750,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))
 
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer, we need to make sure the input has requires_grad=True.
+            # PEFT does the same: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # When training with PEFT, only the LoRA layers have requires_grad set to True, but the outputs of the frozen
+            # layers still need to require gradients so that the gradient can flow back to the adapter weights.
+            self.enable_input_require_grads()
+
     def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.
@@ -1760,6 +1767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
         if self.supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+        if getattr(self, "_hf_peft_config_loaded", False):
+            self.disable_input_require_grads()
+
     @property
     def is_gradient_checkpointing(self) -> bool:
         """
diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py
index 60bda42fd74..b238ce25cb2 100644
--- a/tests/peft_integration/test_peft_integration.py
+++ b/tests/peft_integration/test_peft_integration.py
@@ -179,6 +179,52 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
             model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
             self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
 
+    def test_peft_add_adapter_training_gradient_checkpointing(self):
+        """
+        Simple test that checks if `add_adapter` works as expected when training with
+        gradient checkpointing.
+        """
+        from peft import LoraConfig
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig(init_lora_weights=False)
+
+                model.add_adapter(peft_config)
+
+                self.assertTrue(self._check_lora_correctly_converted(model))
+
+                # When attaching adapters, the input embeddings stay frozen, so the output
+                # of the embedding layer has requires_grad=False.
+                dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
+                frozen_output = model.get_input_embeddings()(dummy_input)
+                self.assertTrue(frozen_output.requires_grad is False)
+
+                model.gradient_checkpointing_enable()
+
+                # Now that `gradient_checkpointing_enable` attached the hook, the embedding
+                # output should have requires_grad set to True.
+                non_frozen_output = model.get_input_embeddings()(dummy_input)
+                self.assertTrue(non_frozen_output.requires_grad is True)
+
+                # Reproduce the Trainer issue: the input itself does not require gradients.
+                dummy_input.requires_grad = False
+
+                for name, param in model.named_parameters():
+                    if "lora" in name.lower():
+                        self.assertTrue(param.requires_grad)
+
+                logits = model(dummy_input).logits
+                loss = logits.mean()
+                loss.backward()
+
+                for name, param in model.named_parameters():
+                    if param.requires_grad:
+                        self.assertTrue("lora" in name.lower())
+                        self.assertTrue(param.grad is not None)
+
     def test_peft_add_multi_adapter(self):
         """
         Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
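For context on the mechanism the patch relies on: `enable_input_require_grads()` registers a forward hook on the input embeddings that flips `requires_grad` on their output, so even though the base weights are frozen, the activations recomputed under gradient checkpointing carry a `grad_fn` and gradients can flow back to the LoRA parameters; `disable_input_require_grads()` removes that hook again. Below is a minimal, self-contained sketch of this behavior; the `TinyModel` class is illustrative only and not part of the patch.

```python
import torch
import torch.nn as nn


# Illustrative stand-in for a PreTrainedModel with frozen base weights;
# only the hook logic mirrors enable/disable_input_require_grads.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 8)
        self.embed.weight.requires_grad_(False)  # frozen, as with PEFT

    def get_input_embeddings(self):
        return self.embed

    def enable_input_require_grads(self):
        # Force the embedding output to require grad so downstream
        # (checkpointed) frozen layers still propagate gradients.
        def make_inputs_require_grads(module, input, output):
            output.requires_grad_(True)

        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(
            make_inputs_require_grads
        )

    def disable_input_require_grads(self):
        self._require_grads_hook.remove()


model = TinyModel()
ids = torch.LongTensor([[0, 1, 2, 3]])

print(model.get_input_embeddings()(ids).requires_grad)  # False: fully frozen
model.enable_input_require_grads()
print(model.get_input_embeddings()(ids).requires_grad)  # True: hook attached
model.disable_input_require_grads()
print(model.get_input_embeddings()(ids).requires_grad)  # False again
```

This is exactly what the new test asserts: the embedding output has `requires_grad=False` before `gradient_checkpointing_enable()` is called and `True` afterwards.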