mirror of https://github.com/huggingface/transformers.git
[PEFT] Fix PEFT + gradient checkpointing (#25846)
* fix PEFT + gradient checkpointing
* add disable RG
* polish tests
* fix comment
* Revert "fix comment"
This reverts commit b85386f50d.
* final explanations and tests
This commit is contained in:
parent ac957f69cc
commit 7c63e6fc8c
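For context, this is the workflow the fix targets: a base model with a PEFT adapter attached via `add_adapter`, trained with gradient checkpointing through `Trainer`. Below is a minimal sketch, assuming a small causal LM and a toy in-memory dataset; the model id, dataset, and hyper-parameters are placeholders and are not part of this commit.

```python
from datasets import Dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, Trainer, TrainingArguments

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # placeholder model id
model.add_adapter(LoraConfig(r=8))  # only the LoRA weights stay trainable

# Toy dataset just to make the sketch self-contained.
train_dataset = Dataset.from_dict({"input_ids": [[0, 1, 2, 3]], "labels": [[0, 1, 2, 3]]})

args = TrainingArguments(
    output_dir="out",
    per_device_train_batch_size=1,
    max_steps=1,
    # Trainer calls model.gradient_checkpointing_enable(); after this commit that call
    # also runs enable_input_require_grads() when a PEFT config is loaded.
    gradient_checkpointing=True,
)
Trainer(model=model, args=args, train_dataset=train_dataset).train()
```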
@@ -1750,6 +1750,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
         self.apply(partial(self._set_gradient_checkpointing, value=True))
 
+        if getattr(self, "_hf_peft_config_loaded", False):
+            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
+            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
+            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
+            # the gradients to make sure the gradient flows.
+            self.enable_input_require_grads()
+
     def gradient_checkpointing_disable(self):
         """
         Deactivates gradient checkpointing for the current model.
@@ -1760,6 +1767,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin
         if self.supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+        if getattr(self, "_hf_peft_config_loaded", False):
+            self.disable_input_require_grads()
+
     @property
     def is_gradient_checkpointing(self) -> bool:
         """
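For reference, `enable_input_require_grads()` and `disable_input_require_grads()` are existing `PreTrainedModel` helpers that this diff only calls: they register and remove a forward hook on the input embeddings that forces the embedding output to require grad. The standalone sketch below shows that mechanism on a plain `nn.Embedding`; it is an illustration, not the actual transformers implementation.

```python
import torch

embed = torch.nn.Embedding(10, 4).requires_grad_(False)  # frozen input embeddings, as under PEFT

# Roughly what enable_input_require_grads() does: a forward hook that marks the
# embedding *output* as requiring grad, so the autograd graph starts at the inputs.
hook = embed.register_forward_hook(lambda module, inputs, output: output.requires_grad_(True))

out = embed(torch.tensor([[1, 2, 3]]))
print(out.requires_grad)  # True, even though the embedding weights stay frozen

hook.remove()  # roughly what disable_input_require_grads() does
print(embed(torch.tensor([[1, 2, 3]])).requires_grad)  # False again
```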
@@ -179,6 +179,52 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
                 model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
                 self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
 
+    def test_peft_add_adapter_training_gradient_checkpointing(self):
+        """
+        Simple test that tests if `add_adapter` works as expected when training with
+        gradient checkpointing.
+        """
+        from peft import LoraConfig
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig(init_lora_weights=False)
+
+                model.add_adapter(peft_config)
+
+                self.assertTrue(self._check_lora_correctly_converted(model))
+
+                # When attaching adapters the input embeddings will stay frozen, this will
+                # lead to the output embedding having requires_grad=False.
+                dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)
+                frozen_output = model.get_input_embeddings()(dummy_input)
+                self.assertTrue(frozen_output.requires_grad is False)
+
+                model.gradient_checkpointing_enable()
+
+                # Since here we attached the hook, the input should have requires_grad to set
+                # properly
+                non_frozen_output = model.get_input_embeddings()(dummy_input)
+                self.assertTrue(non_frozen_output.requires_grad is True)
+
+                # To repro the Trainer issue
+                dummy_input.requires_grad = False
+
+                for name, param in model.named_parameters():
+                    if "lora" in name.lower():
+                        self.assertTrue(param.requires_grad)
+
+                logits = model(dummy_input).logits
+                loss = logits.mean()
+                loss.backward()
+
+                for name, param in model.named_parameters():
+                    if param.requires_grad:
+                        self.assertTrue("lora" in name.lower())
+                        self.assertTrue(param.grad is not None)
+
     def test_peft_add_multi_adapter(self):
         """
         Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
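The new test reproduces the underlying failure mode end to end. As a standalone illustration with plain `torch.utils.checkpoint`, independent of transformers and of this commit: a checkpointed block only propagates gradients if at least one of its tensor inputs requires grad, which is why the frozen embedding output has to be flipped to `requires_grad=True`.

```python
import torch
from torch.utils.checkpoint import checkpoint

block = torch.nn.Linear(4, 4)  # stands in for a transformer block wrapping trainable (LoRA-like) weights
hidden = torch.randn(1, 3, 4)  # like the output of frozen input embeddings: requires_grad=False

out = checkpoint(block, hidden, use_reentrant=True)
print(out.requires_grad)       # False -> a later loss.backward() fails and trainable grads stay None

hidden.requires_grad_(True)    # the effect of enable_input_require_grads()
out = checkpoint(block, hidden, use_reentrant=True)
out.mean().backward()
print(block.weight.grad is not None)  # True -> gradients now reach the trainable weights
```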