[PEFT] Fix PeftConfig save_pretrained when calling add_adapter (#25738)

fix save_pretrained issue + add test
Younes Belkada 2023-08-25 08:19:11 +02:00 committed by GitHub
parent f26099e7b5
commit ae320fa53f
2 changed files with 23 additions and 0 deletions
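
In short: before this change, `add_adapter` injected the adapter without recording the base model on the `PeftConfig`, so the adapter config written out by `save_pretrained` lacked `base_model_name_or_path` and the checkpoint did not round-trip cleanly through `from_pretrained`. A minimal sketch of the flow this commit fixes (the model id is illustrative; any supported transformers model works):

import tempfile

from peft import LoraConfig
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # illustrative model id

# Inject a fresh LoRA adapter; with this fix, add_adapter also records the
# base model's name or path on the adapter config before injection.
model.add_adapter(LoraConfig(init_lora_weights=False))

with tempfile.TemporaryDirectory() as tmpdirname:
    model.save_pretrained(tmpdirname)
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdirname)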

@@ -216,6 +216,9 @@ class PeftAdapterMixin:
                 f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
             )
+        # Retrieve the name or path of the model, one could also use self.config._name_or_path
+        # but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100
+        adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None)
         inject_adapter_in_model(adapter_config, self, adapter_name)
         self.set_adapter(adapter_name)
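
The `self.__dict__.get("name_or_path", None)` lookup mirrors the linked PEFT code and is deliberately defensive: on a model that was never loaded through `from_pretrained` the attribute may be absent, and the dict lookup then yields None instead of raising. A tiny sketch of that behavior (the class below is a hypothetical stand-in, not from the commit):

class DummyModel:  # hypothetical stand-in for a transformers model
    pass

loaded = DummyModel()
loaded.name_or_path = "facebook/opt-350m"  # set when loading via from_pretrained in practice
scratch = DummyModel()                     # built from scratch, attribute never set

print(loaded.__dict__.get("name_or_path", None))   # facebook/opt-350m
print(scratch.__dict__.get("name_or_path", None))  # None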

@@ -159,6 +159,26 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
                 # dummy generation
                 _ = model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
 
+    def test_peft_add_adapter_from_pretrained(self):
+        """
+        Simple test that tests if `add_adapter` works as expected
+        """
+        from peft import LoraConfig
+
+        for model_id in self.transformers_test_model_ids:
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig(init_lora_weights=False)
+
+                model.add_adapter(peft_config)
+                self.assertTrue(self._check_lora_correctly_converted(model))
+
+                with tempfile.TemporaryDirectory() as tmpdirname:
+                    model.save_pretrained(tmpdirname)
+                    model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
+                    self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
+
     def test_peft_add_multi_adapter(self):
         """
         Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if