mirror of https://github.com/huggingface/transformers.git
[PEFT] Support low_cpu_mem_usage option for PEFT loading adapters (#33725)
* [PEFT] Support low_cpu_mem_usage for PEFT loading

PEFT added support for low_cpu_mem_usage=True when loading adapters in
https://github.com/huggingface/peft/pull/1961. This feature is available when
installing PEFT v0.13.0 or higher. With this PR, the option is also supported
when loading PEFT adapters directly into transformers models.

Additionally, this PR unblocks https://github.com/huggingface/diffusers/pull/9510,
which implements the same option in diffusers.

* Fix typo
This commit is contained in:
parent
bf0ffe3d29
commit
6500f78c86
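With this change, the memory-friendly loading path can be enabled directly from transformers. A minimal usage sketch, assuming PEFT >= 0.13.0 is installed; the model and adapter ids below are placeholders, not taken from this commit:

from transformers import AutoModelForCausalLM

# Placeholder checkpoints for illustration only.
base_model_id = "facebook/opt-350m"
adapter_id = "your-org/your-lora-adapter"

model = AutoModelForCausalLM.from_pretrained(base_model_id)

# New in this commit: low_cpu_mem_usage is forwarded to PEFT, which creates the
# adapter weights on the meta device and materializes them only once the state
# dict is assigned, lowering peak CPU memory while loading.
model.load_adapter(adapter_id, low_cpu_mem_usage=True)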
@@ -11,10 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import inspect
 import warnings
 from typing import Any, Dict, List, Optional, Union
 
+from packaging import version
+
 from ..utils import (
     check_peft_version,
     find_adapter_config_file,
@@ -77,6 +80,7 @@ class PeftAdapterMixin:
         offload_index: Optional[int] = None,
         peft_config: Dict[str, Any] = None,
         adapter_state_dict: Optional[Dict[str, "torch.Tensor"]] = None,
+        low_cpu_mem_usage: bool = False,
         adapter_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """
@@ -129,12 +133,27 @@
             adapter_state_dict (`Dict[str, torch.Tensor]`, *optional*):
                 The state dict of the adapter to load. This argument is used in case users directly pass PEFT state
                 dicts
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `False`):
+                Reduce memory usage while loading the PEFT adapter. This should also speed up the loading process.
+                Requires PEFT version 0.13.0 or higher.
             adapter_kwargs (`Dict[str, Any]`, *optional*):
                 Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and
                 `find_adapter_config_file` method.
         """
         check_peft_version(min_version=MIN_PEFT_VERSION)
 
+        # peft only supports low_cpu_mem_usage starting from v0.13.0
+        peft_load_kwargs = {}
+        if low_cpu_mem_usage:
+            min_version_lcmu = "0.13.0"
+            if version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu):
+                peft_load_kwargs["low_cpu_mem_usage"] = low_cpu_mem_usage
+            else:
+                raise ValueError(
+                    "The version of PEFT you are using does not support `low_cpu_mem_usage` yet, "
+                    f"please install PEFT >= {min_version_lcmu}."
+                )
+
         adapter_name = adapter_name if adapter_name is not None else "default"
         if adapter_kwargs is None:
             adapter_kwargs = {}
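The version gate above can also be checked up front in user code before opting into the option. A small standalone sketch of the same check (the helper name is ours, not part of this commit):

import importlib.metadata

from packaging import version


def peft_supports_low_cpu_mem_usage() -> bool:
    # low_cpu_mem_usage support landed in PEFT v0.13.0 (huggingface/peft#1961).
    return version.parse(importlib.metadata.version("peft")) >= version.parse("0.13.0")


print(peft_supports_low_cpu_mem_usage())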
@@ -192,7 +211,7 @@
             )
 
         # Create and add fresh new adapters into the model.
-        inject_adapter_in_model(peft_config, self, adapter_name)
+        inject_adapter_in_model(peft_config, self, adapter_name, **peft_load_kwargs)
 
         if not self._hf_peft_config_loaded:
             self._hf_peft_config_loaded = True
@@ -211,7 +230,9 @@
             processed_adapter_state_dict[new_key] = value
 
         # Load state dict
-        incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)
+        incompatible_keys = set_peft_model_state_dict(
+            self, processed_adapter_state_dict, adapter_name, **peft_load_kwargs
+        )
 
         if incompatible_keys is not None:
             # check only for unexpected keys
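The two call sites above are where the new flag actually reaches PEFT. A condensed sketch of that flow, assuming PEFT >= 0.13.0; it omits the config resolution, key remapping, and device-map handling that load_adapter performs around these calls, and the base model id is a placeholder:

from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")  # placeholder base model
peft_config = LoraConfig()

# With low_cpu_mem_usage=True (PEFT >= 0.13.0), the freshly injected adapter
# weights are created on the meta device instead of being allocated on CPU.
inject_adapter_in_model(peft_config, model, "default", low_cpu_mem_usage=True)

# Assigning the real weights afterwards materializes the meta parameters. The
# mixin remaps state dict keys (e.g. stripping "base_model.model.") before this
# call; the dict is left empty here purely to keep the sketch self-contained.
adapter_state_dict = {}
set_peft_model_state_dict(model, adapter_state_dict, "default", low_cpu_mem_usage=True)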
@@ -12,11 +12,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
 import os
 import tempfile
 import unittest
 
 from huggingface_hub import hf_hub_download
+from packaging import version
 
 from transformers import AutoModelForCausalLM, OPTForCausalLM
 from transformers.testing_utils import (
@@ -478,6 +480,48 @@ class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
                 # dummy generation
                 _ = model.generate(input_ids=dummy_input)
 
+    def test_peft_add_adapter_with_state_dict_low_cpu_mem_usage(self):
+        """
+        Check the usage of low_cpu_mem_usage, which is supported in PEFT >= 0.13.0
+        """
+        from peft import LoraConfig
+
+        min_version_lcmu = "0.13.0"
+        is_lcmu_supported = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version_lcmu)
+
+        for model_id, peft_model_id in zip(self.transformers_test_model_ids, self.peft_test_model_ids):
+            for transformers_class in self.transformers_test_model_classes:
+                model = transformers_class.from_pretrained(model_id).to(torch_device)
+
+                peft_config = LoraConfig()
+                state_dict_path = hf_hub_download(peft_model_id, "adapter_model.bin")
+                dummy_state_dict = torch.load(state_dict_path)
+
+                # this should always work
+                model.load_adapter(
+                    adapter_state_dict=dummy_state_dict, peft_config=peft_config, low_cpu_mem_usage=False
+                )
+
+                if is_lcmu_supported:
+                    # if supported, this should not raise an error
+                    model.load_adapter(
+                        adapter_state_dict=dummy_state_dict,
+                        adapter_name="other",
+                        peft_config=peft_config,
+                        low_cpu_mem_usage=True,
+                    )
+                    # after loading, no meta device should be remaining
+                    self.assertFalse(any((p.device.type == "meta") for p in model.parameters()))
+                else:
+                    err_msg = r"The version of PEFT you are using does not support `low_cpu_mem_usage` yet"
+                    with self.assertRaisesRegex(ValueError, err_msg):
+                        model.load_adapter(
+                            adapter_state_dict=dummy_state_dict,
+                            adapter_name="other",
+                            peft_config=peft_config,
+                            low_cpu_mem_usage=True,
+                        )
+
     def test_peft_from_pretrained_hub_kwargs(self):
         """
         Tests different combinations of PEFT model + from_pretrained + hub kwargs