mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-04 21:30:07 +06:00
[PEFT
] Fix PeftConfig save pretrained when calling add_adapter
(#25738)
fix save_pretrained issue + add test
This commit is contained in:
parent
f26099e7b5
commit
ae320fa53f
@ -216,6 +216,9 @@ class PeftAdapterMixin:
|
||||
f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
|
||||
)
|
||||
|
||||
# Retrieve the name or path of the model, one could also use self.config._name_or_path
|
||||
# but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100
|
||||
adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None)
|
||||
inject_adapter_in_model(adapter_config, self, adapter_name)
|
||||
|
||||
self.set_adapter(adapter_name)
|
||||
|
@ -159,6 +159,26 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
|
||||
# dummy generation
|
||||
_ = model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))
|
||||
|
||||
def test_peft_add_adapter_from_pretrained(self):
|
||||
"""
|
||||
Simple test that tests if `add_adapter` works as expected
|
||||
"""
|
||||
from peft import LoraConfig
|
||||
|
||||
for model_id in self.transformers_test_model_ids:
|
||||
for transformers_class in self.transformers_test_model_classes:
|
||||
model = transformers_class.from_pretrained(model_id).to(torch_device)
|
||||
|
||||
peft_config = LoraConfig(init_lora_weights=False)
|
||||
|
||||
model.add_adapter(peft_config)
|
||||
|
||||
self.assertTrue(self._check_lora_correctly_converted(model))
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
model_from_pretrained = transformers_class.from_pretrained(tmpdirname).to(torch_device)
|
||||
self.assertTrue(self._check_lora_correctly_converted(model_from_pretrained))
|
||||
|
||||
def test_peft_add_multi_adapter(self):
|
||||
"""
|
||||
Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
|
||||
|
Loading…
Reference in New Issue
Block a user