[PEFT] Fix PEFT + gradient checkpointing (#25846)

* fix PEFT + gradient checkpointing

* add disable RG

* polish tests

* fix comment

* Revert "fix comment"

This reverts commit b85386f50d.

* final explanations and tests
This commit is contained in:
Younes Belkada 2023-09-14 13:01:58 +02:00 committed by GitHub
parent ac957f69cc
commit 7c63e6fc8c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 56 additions and 0 deletions

View File

@ -1750,6 +1750,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
if getattr(self, "_hf_peft_config_loaded", False):
# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
# the gradients to make sure the gradient flows.
self.enable_input_require_grads()
def gradient_checkpointing_disable(self):
"""
Deactivates gradient checkpointing for the current model.
@ -1760,6 +1767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if self.supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
if getattr(self, "_hf_peft_config_loaded", False):
self.disable_input_require_grads()
@property
def is_gradient_checkpointing(self) -> bool:
"""

View File

@ -179,6 +179,52 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
def test_peft_add_adapter_training_gradient_checkpointing(self):
"""
Simple test that tests if `add_adapter` works as expected when training with
gradient checkpointing.
"""
from peft import LoraConfig
for model_id in self.transformers_test_model_ids:
for transformers_class in self.transformers_test_model_classes:
model = transformers_class.from_pretrained(model_id).to(torch_device)
peft_config = LoraConfig(init_lora_weights=False)
model.add_adapter(peft_config)
self.assertTrue(self._check_lora_correctly_converted(model))
# When attaching adapters the input embeddings will stay frozen, this will
# lead to the output embedding having requires_grad=False.
dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
frozen_output = model.get_input_embeddings()(dummy_input)
self.assertTrue(frozen_output.requires_grad is False)
model.gradient_checkpointing_enable()
# Since here we attached the hook, the input should have requires_grad to set
# properly
non_frozen_output = model.get_input_embeddings()(dummy_input)
self.assertTrue(non_frozen_output.requires_grad is True)
# To repro the Trainer issue
dummy_input.requires_grad = False
for name, param in model.named_parameters():
if "lora" in name.lower():
self.assertTrue(param.requires_grad)
logits = model(dummy_input).logits
loss = logits.mean()
loss.backward()
for name, param in model.named_parameters():
if param.requires_grad:
self.assertTrue("lora" in name.lower())
self.assertTrue(param.grad is not None)
def test_peft_add_multi_adapter(self):
"""
Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if