From ae320fa53f74cc4dfa0e4fc3c95b6129a86b0512 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 25 Aug 2023 08:19:11 +0200 Subject: [PATCH] [`PEFT`] Fix PeftConfig save pretrained when calling `add_adapter` (#25738) fix save_pretrained issue + add test --- .../lib_integrations/peft/peft_mixin.py | 3 +++ .../peft_integration/test_peft_integration.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/transformers/lib_integrations/peft/peft_mixin.py b/src/transformers/lib_integrations/peft/peft_mixin.py index 78e5fd3e9a7..7a1f7c1f582 100644 --- a/src/transformers/lib_integrations/peft/peft_mixin.py +++ b/src/transformers/lib_integrations/peft/peft_mixin.py @@ -216,6 +216,9 @@ class PeftAdapterMixin: f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." ) + # Retrieve the name or path of the model, one could also use self.config._name_or_path + # but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100 + adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None) inject_adapter_in_model(adapter_config, self, adapter_name) self.set_adapter(adapter_name) diff --git a/tests/peft_integration/test_peft_integration.py b/tests/peft_integration/test_peft_integration.py index b80912607c0..60bda42fd74 100644 --- a/tests/peft_integration/test_peft_integration.py +++ b/tests/peft_integration/test_peft_integration.py @@ -159,6 +159,26 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin): # dummy generation _ = model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)) + def test_peft_add_adapter_from_pretrained(self): + """ + Simple test that tests if `add_adapter` works as expected + """ + from peft import LoraConfig + + for model_id in self.transformers_test_model_ids: + for transformers_class in self.transformers_test_model_classes: + model = transformers_class.from_pretrained(model_id).to(torch_device) + + peft_config = LoraConfig(init_lora_weights=False) + + model.add_adapter(peft_config) + + self.assertTrue(self._check_lora_correctly_converted(model)) + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device) + self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained)) + def test_peft_add_multi_adapter(self): """ Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if