Mirror of https://github.com/huggingface/transformers.git (synced 2025-07-31 02:02:21 +06:00)
[PEFT] Peft integration alternative design (#25077)

* a draft version
* v2 integration
* fix
* make it more generic and works for IA3
* add set adapter and multiple adapters support
* fixup
* adapt a bit
* oops
* oops
* oops
* adapt more
* fix
* add more refactor
* now works with model class
* change it to instance method as it causes issues with `jit`.
* add CR
* change method name
* add `add_adapter` method
* clean up
* Update src/transformers/adapters/peft_mixin.py (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* add moe utils
* fixup
* Update src/transformers/adapters/peft_mixin.py (Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>)
* adapt
* oops
* fixup
* add is_peft_available
* remove `requires_backend`
* trainer compatibility
* fixup + docstring
* more details
* trigger CI
* Apply suggestions from code review (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* Update src/transformers/modeling_utils.py
* fixup + is_main_process
* added `save_peft_format` in save_pretrained
* up
* fix nits here and there
* nits here and there.
* docs
* revert `encoding="utf-8"`
* comment
* added slow tests before the PEFT release.
* fixup and nits
* let's be on the safe zone
* added more comments
* v1 docs
* add remaining docs
* Apply suggestions from code review (Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>)
* move to `lib_integrations`
* fixup
* this time fixup
* Apply suggestions from code review (Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>)
* address final comments
* refactor to use `token`
* add PEFT to DockerFile for slow tests.
* added pipeline support.

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
This commit is contained in:
parent ef1534252f
commit faed2ca46f
@@ -44,6 +44,8 @@ RUN python3 -m pip install -U "itsdangerous<2.1.0"

RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/accelerate@main#egg=accelerate

+RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/peft@main#egg=peft
+
# Add bitsandbytes for mixed int8 testing
RUN python3 -m pip install --no-cache-dir bitsandbytes
@@ -19,6 +19,8 @@
    title: Train with a script
  - local: accelerate
    title: Set up distributed training with 🤗 Accelerate
+  - local: peft
+    title: Load and train adapters with 🤗 PEFT
  - local: model_sharing
    title: Share your model
  - local: transformers_agents
216  docs/source/en/peft.md  Normal file
@@ -0,0 +1,216 @@
<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Load adapters with 🤗 PEFT

[[open-in-colab]]

[Parameter-Efficient Fine Tuning (PEFT)](https://huggingface.co/blog/peft) methods freeze the pretrained model parameters during fine-tuning and add a small number of trainable parameters (the adapters) on top of it. The adapters are trained to learn task-specific information. This approach has been shown to be very memory-efficient with lower compute usage while producing results comparable to a fully fine-tuned model.

Adapters trained with PEFT are also usually an order of magnitude smaller than the full model, making it convenient to share, store, and load them.

<div class="flex flex-col justify-center">
  <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/peft/PEFT-hub-screenshot.png"/>
  <figcaption class="text-center">The adapter weights for an OPTForCausalLM model stored on the Hub are only ~6MB compared to the full size of the model weights, which can be ~700MB.</figcaption>
</div>

If you're interested in learning more about the 🤗 PEFT library, check out the [documentation](https://huggingface.co/docs/peft/index).

## Setup

Get started by installing 🤗 PEFT:

```bash
pip install peft
```

If you want to try out the brand new features, you might be interested in installing the library from source:

```bash
pip install git+https://github.com/huggingface/peft.git
```

## Supported PEFT models

🤗 Transformers natively supports some PEFT methods, meaning you can load adapter weights stored locally or on the Hub and easily run or train them with a few lines of code. The following methods are supported:

- [Low Rank Adapters](https://huggingface.co/docs/peft/conceptual_guides/lora)
- [IA3](https://huggingface.co/docs/peft/conceptual_guides/ia3)
- [AdaLoRA](https://arxiv.org/abs/2303.10512)

If you want to use other PEFT methods, such as prompt learning or prompt tuning, or to learn more about the 🤗 PEFT library in general, please refer to the [documentation](https://huggingface.co/docs/peft/index).

## Load a PEFT adapter

To load and use a PEFT adapter model from 🤗 Transformers, make sure the Hub repository or local directory contains an `adapter_config.json` file and the adapter weights, as shown in the example image above. Then you can load the PEFT adapter model using the `AutoModelFor` class. For example, to load a PEFT adapter model for causal language modeling:

1. specify the PEFT model id
2. pass it to the [`AutoModelForCausalLM`] class

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id)
```

<Tip>

You can load a PEFT adapter with either an `AutoModelFor` class or the base model class like `OPTForCausalLM` or `LlamaForCausalLM`.

</Tip>
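For example, a minimal sketch of the same load through the base model class, using the `ybelkada/opt-350m-lora` checkpoint from the example above:

```py
from transformers import OPTForCausalLM

# Loading the adapter repository through the base model class behaves like the
# AutoModelForCausalLM example above: the base weights are resolved from
# `adapter_config.json` and the LoRA weights are injected on top.
model = OPTForCausalLM.from_pretrained("ybelkada/opt-350m-lora")
```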
You can also load a PEFT adapter by calling the `load_adapter` method:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "facebook/opt-350m"
peft_model_id = "ybelkada/opt-350m-lora"

model = AutoModelForCausalLM.from_pretrained(model_id)
model.load_adapter(peft_model_id)
```

## Load in 8bit or 4bit

The `bitsandbytes` integration supports 8bit and 4bit precision data types, which are useful for loading large models because it saves memory (see the `bitsandbytes` integration [guide](./quantization#bitsandbytes-integration) to learn more). Add the `load_in_8bit` or `load_in_4bit` parameters to [`~PreTrainedModel.from_pretrained`] and set `device_map="auto"` to effectively distribute the model to your hardware:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

peft_model_id = "ybelkada/opt-350m-lora"
model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", load_in_8bit=True)
```
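The same pattern works for 4bit precision; a minimal sketch (it requires a GPU and a recent `bitsandbytes`, and the exact memory savings depend on your hardware):

```py
from transformers import AutoModelForCausalLM

peft_model_id = "ybelkada/opt-350m-lora"
# Identical to the 8bit example, only the quantization flag changes.
model = AutoModelForCausalLM.from_pretrained(peft_model_id, device_map="auto", load_in_4bit=True)
```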
## Add a new adapter

You can use [`~peft.PeftModel.add_adapter`] to add a new adapter to a model with an existing adapter as long as the new adapter is the same type as the current one. For example, if you have an existing LoRA adapter attached to a model:

```py
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import LoraConfig

model_id = "facebook/opt-350m"
model = AutoModelForCausalLM.from_pretrained(model_id)

lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    init_lora_weights=False
)

model.add_adapter(lora_config, adapter_name="adapter_1")
```

To add a new adapter:

```py
# attach new adapter with same config
model.add_adapter(lora_config, adapter_name="adapter_2")
```

Now you can use [`~peft.PeftModel.set_adapter`] to set which adapter to use:

```py
# use adapter_1
model.set_adapter("adapter_1")
output = model.generate(**inputs)
print(tokenizer.decode(output[0], skip_special_tokens=True))

# use adapter_2
model.set_adapter("adapter_2")
output = model.generate(**inputs)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

## Enable and disable adapters

Once you've added an adapter to a model, you can enable or disable the adapter module. To enable the adapter module:

```py
from transformers import AutoModelForCausalLM, OPTForCausalLM, AutoTokenizer
from peft import PeftConfig

model_id = "facebook/opt-350m"
adapter_model_id = "ybelkada/opt-350m-lora"
tokenizer = AutoTokenizer.from_pretrained(model_id)
text = "Hello"
inputs = tokenizer(text, return_tensors="pt")

model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PeftConfig.from_pretrained(adapter_model_id)

# to initialize with random weights
peft_config.init_lora_weights = False

model.add_adapter(peft_config)
model.enable_adapters()
output = model.generate(**inputs)
```

To disable the adapter module:

```py
model.disable_adapters()
output = model.generate(**inputs)
```

## Train a PEFT adapter

PEFT adapters are supported by the [`Trainer`] class so that you can train an adapter for your specific use case. It only requires adding a few more lines of code. For example, to train a LoRA adapter:

<Tip>

If you aren't familiar with fine-tuning a model with [`Trainer`], take a look at the [Fine-tune a pretrained model](training) tutorial.

</Tip>

1. Define your adapter configuration with the task type and hyperparameters (see [`~peft.LoraConfig`] for more details about what the hyperparameters do).

```py
from peft import LoraConfig

peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    bias="none",
    task_type="CAUSAL_LM",
)
```

2. Add the adapter to the model.

```py
model.add_adapter(peft_config)
```

3. Now you can pass the model to [`Trainer`]!

```py
trainer = Trainer(model=model, ...)
trainer.train()
```
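For context, here is a rough end-to-end sketch of the three steps above put together; the dataset, tokenizer, and `TrainingArguments` values are illustrative assumptions rather than part of the original example:

```py
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

model_id = "facebook/opt-350m"  # assumed base model, as in the earlier examples
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

# Step 1 + 2: define the LoRA config and attach it to the model
peft_config = LoraConfig(lora_alpha=16, lora_dropout=0.1, r=64, bias="none", task_type="CAUSAL_LM")
model.add_adapter(peft_config)

# Illustrative dataset; replace with your own corpus
dataset = load_dataset("imdb", split="train[:1%]")
dataset = dataset.map(lambda batch: tokenizer(batch["text"], truncation=True, max_length=128), batched=True)

# Step 3: only the adapter parameters are trained
trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="opt-350m-lora", per_device_train_batch_size=4, num_train_epochs=1),
    train_dataset=dataset,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```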
To save your trained adapter and load it back:

```py
model.save_pretrained(save_dir)
model = AutoModelForCausalLM.from_pretrained(save_dir)
```

<!--
TODO: (@younesbelkada @stevhliu)
- Link to PEFT docs for further details
- Trainer
- 8-bit / 4-bit examples ?
-->
@@ -111,6 +111,8 @@ _import_structure = {
        "is_tensorboard_available",
        "is_wandb_available",
    ],
+    "lib_integrations": [],
+    "lib_integrations.peft": [],
    "modelcard": ["ModelCard"],
    "modeling_tf_pytorch_utils": [
        "convert_tf_weight_name_to_pt_weight_name",
14  src/transformers/lib_integrations/__init__.py  Normal file
@@ -0,0 +1,14 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .peft import PeftAdapterMixin
15  src/transformers/lib_integrations/peft/__init__.py  Normal file
@@ -0,0 +1,15 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .peft_mixin import PeftAdapterMixin
390  src/transformers/lib_integrations/peft/peft_mixin.py  Normal file
@@ -0,0 +1,390 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import Optional

from ...utils import (
    check_peft_version,
    find_adapter_config_file,
    is_accelerate_available,
    is_peft_available,
    logging,
)


if is_accelerate_available():
    from accelerate import dispatch_model
    from accelerate.utils import get_balanced_memory, infer_auto_device_map


logger = logging.get_logger(__name__)


class PeftAdapterMixin:
    """
    A class containing all functions for loading and using adapter weights that are supported in the PEFT library.
    For more details about adapters and injecting them into a transformer-based model, check out the PEFT library
    documentation: https://huggingface.co/docs/peft/index

    Currently supported PEFT methods are all non-prefix tuning methods. Below is the list of supported PEFT methods
    that anyone can load, train and run with this mixin class:
    - Low Rank Adapters (LoRA): https://huggingface.co/docs/peft/conceptual_guides/lora
    - IA3: https://huggingface.co/docs/peft/conceptual_guides/ia3
    - AdaLora: https://arxiv.org/abs/2303.10512

    Other PEFT models such as prompt tuning and prompt learning are out of scope as these adapters are not
    "injectable" into a torch module. For using these methods, please refer to the usage guide of the PEFT library.

    With this mixin, if the correct PEFT version is installed, it is possible to:

    - Load an adapter stored on a local path or in a remote Hub repository, and inject it in the model
    - Attach new adapters in the model and train them with Trainer or on your own.
    - Attach multiple adapters and iteratively activate / deactivate them
    - Activate / deactivate all adapters from the model.
    - Get the `state_dict` of the active adapter.
    """

    _hf_peft_config_loaded = False

    def load_adapter(
        self,
        peft_model_id: str,
        adapter_name: Optional[str] = None,
        revision: Optional[str] = None,
        token: Optional[str] = None,
        device_map: Optional[str] = "auto",
        max_memory: Optional[str] = None,
        offload_folder: Optional[str] = None,
        offload_index: Optional[int] = None,
    ) -> None:
        """
        Load adapter weights from a file or a remote Hub folder. If you are not familiar with adapters and PEFT
        methods, we invite you to read more about them on the PEFT official documentation:
        https://huggingface.co/docs/peft

        Requires peft as a backend to load the adapter weights.

        Args:
            peft_model_id (`str`):
                The identifier of the model to look for on the Hub, or a local path to the saved adapter config file
                and adapter weights.
            adapter_name (`str`, *optional*):
                The adapter name to use. If not set, will use the default adapter.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.

                <Tip>

                To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

                </Tip>

            token (`str`, *optional*):
                The authentication token to use to load the remote folder. Useful to load private repositories that
                are on the HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to
                cache it.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
                like `1`) on which the model will be allocated, the device map will map the entire model to this
                device. Passing `device_map = 0` means put the whole model on GPU 0.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
            offload_index (`int`, *optional*):
                `offload_index` argument to be passed to the `accelerate.dispatch_model` method.
        """
        check_peft_version(min_version="0.4.0")

        adapter_name = adapter_name if adapter_name is not None else "default"

        from peft import PeftConfig, inject_adapter_in_model, load_peft_weights
        from peft.utils import set_peft_model_state_dict

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        adapter_config_file = find_adapter_config_file(
            peft_model_id,
            revision=revision,
            token=token,
        )

        if adapter_config_file is None:
            raise ValueError(
                f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the "
                "adapter model."
            )

        loaded_peft_config = PeftConfig.from_pretrained(
            peft_model_id,
            revision=revision,
            use_auth_token=token,
        )

        # Create and add fresh new adapters into the model.
        inject_adapter_in_model(loaded_peft_config, self, adapter_name)

        adapter_state_dict = load_peft_weights(peft_model_id, revision=revision, use_auth_token=token)

        # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility
        processed_adapter_state_dict = {}
        prefix = "base_model.model."
        for key, value in adapter_state_dict.items():
            if key.startswith(prefix):
                new_key = key[len(prefix) :]
            else:
                new_key = key
            processed_adapter_state_dict[new_key] = value

        # Load state dict
        incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name)

        if incompatible_keys is not None:
            # check only for unexpected keys
            if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0:
                logger.warning(
                    f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: "
                    f" {incompatible_keys.unexpected_keys}. "
                )

        # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk.
        if (
            (getattr(self, "hf_device_map", None) is not None)
            and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0)
            and len(self.peft_config) == 1
        ):
            self._dispatch_accelerate_model(
                device_map=device_map,
                max_memory=max_memory,
                offload_folder=offload_folder,
                offload_index=offload_index,
            )

    def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None:
        r"""
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Adds a fresh new adapter to the current model for training purposes. If no adapter name is passed, a default
        name is assigned to the adapter to follow the convention of the PEFT library (in PEFT we use "default" as the
        default adapter name).

        Args:
            adapter_config (`~peft.PeftConfig`):
                The configuration of the adapter to add; supported adapters are non-prefix tuning and adaption prompt
                methods.
            adapter_name (`str`, *optional*, defaults to `"default"`):
                The name of the adapter to add. If no name is passed, a default name is assigned to the adapter.
        """
        check_peft_version(min_version="0.4.0")

        from peft import PeftConfig, inject_adapter_in_model

        adapter_name = adapter_name or "default"

        if not self._hf_peft_config_loaded:
            self._hf_peft_config_loaded = True
        elif adapter_name in self.peft_config:
            raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.")

        if not isinstance(adapter_config, PeftConfig):
            raise ValueError(
                f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead."
            )

        inject_adapter_in_model(adapter_config, self, adapter_name)

        self.set_adapter(adapter_name)

    def set_adapter(self, adapter_name: str) -> None:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Sets a specific adapter by forcing the model to use that adapter and disabling the other adapters.

        Args:
            adapter_name (`str`):
                The name of the adapter to set.
        """
        check_peft_version(min_version="0.4.0")
        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")
        elif adapter_name not in self.peft_config:
            raise ValueError(
                f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}"
            )

        from peft.tuners.tuners_utils import BaseTunerLayer

        _adapters_has_been_set = False

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                module.active_adapter = adapter_name
                _adapters_has_been_set = True

        if not _adapters_has_been_set:
            raise ValueError(
                "Did not succeed in setting the adapter. Please make sure you are using a model that supports adapters."
            )

    def disable_adapters(self) -> None:
        r"""
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Disable all adapters that are attached to the model. This leads to inferring with the base model only.
        """
        check_peft_version(min_version="0.4.0")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                module.disable_adapters = True

    def enable_adapters(self) -> None:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Enable adapters that are attached to the model. The model will use `self.active_adapter()`.
        """
        check_peft_version(min_version="0.4.0")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                module.disable_adapters = False

    def active_adapter(self) -> str:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Gets the current active adapter of the model.
        """
        check_peft_version(min_version="0.4.0")

        if not is_peft_available():
            raise ImportError("PEFT is not available. Please install PEFT to use this function: `pip install peft`.")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft.tuners.tuners_utils import BaseTunerLayer

        for _, module in self.named_modules():
            if isinstance(module, BaseTunerLayer):
                return module.active_adapter

    def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict:
        """
        If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT
        official documentation: https://huggingface.co/docs/peft

        Gets the adapter state dict that should only contain the weight tensors of the specified adapter_name
        adapter. If no adapter_name is passed, the active adapter is used.

        Args:
            adapter_name (`str`, *optional*):
                The name of the adapter to get the state dict from. If no name is passed, the active adapter is used.
        """
        check_peft_version(min_version="0.4.0")

        if not self._hf_peft_config_loaded:
            raise ValueError("No adapter loaded. Please load an adapter first.")

        from peft import get_peft_model_state_dict

        if adapter_name is None:
            adapter_name = self.active_adapter()

        adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name)
        return adapter_state_dict

    def _dispatch_accelerate_model(
        self,
        device_map: str,
        max_memory: Optional[int] = None,
        offload_folder: Optional[str] = None,
        offload_index: Optional[int] = None,
    ) -> None:
        """
        Optionally re-dispatches the model and attaches new hooks to the model in case the model has been loaded with
        accelerate (i.e. with `device_map=xxx`).

        Args:
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
                A map that specifies where each submodule should go. It doesn't need to be refined to each
                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
                same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
                like `1`) on which the model will be allocated, the device map will map the entire model to this
                device. Passing `device_map = 0` means put the whole model on GPU 0.

                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
                more information about each option see [designing a device
                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
            max_memory (`Dict`, *optional*):
                A dictionary mapping device identifiers to their maximum memory. Will default to the maximum memory
                available for each GPU and the available CPU RAM if unset.
            offload_folder (`str` or `os.PathLike`, *optional*):
                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
            offload_index (`int`, *optional*):
                The `offload_index` argument to be passed to the `accelerate.dispatch_model` method.
        """
        dispatch_model_kwargs = {}
        # Safety checker for previous `accelerate` versions
        # `offload_index` was introduced in https://github.com/huggingface/accelerate/pull/873/
        if "offload_index" in inspect.signature(dispatch_model).parameters:
            dispatch_model_kwargs["offload_index"] = offload_index

        no_split_module_classes = self._no_split_modules

        if device_map != "sequential":
            max_memory = get_balanced_memory(
                self,
                max_memory=max_memory,
                no_split_module_classes=no_split_module_classes,
                low_zero=(device_map == "balanced_low_0"),
            )
        if isinstance(device_map, str):
            device_map = infer_auto_device_map(
                self, max_memory=max_memory, no_split_module_classes=no_split_module_classes
            )
        dispatch_model(
            self,
            device_map=device_map,
            offload_dir=offload_folder,
            **dispatch_model_kwargs,
        )
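Taken together, the mixin gives every `PreTrainedModel` the adapter API used in the documentation above; a short usage sketch (the model and adapter ids are the ones from the docs, shown here only for illustration):

```py
from transformers import AutoModelForCausalLM

# Any PreTrainedModel exposes these methods once `peft>=0.4.0` is installed.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter("ybelkada/opt-350m-lora", adapter_name="lora_1")

model.set_adapter("lora_1")                    # route forward passes through this adapter
print(model.active_adapter())                  # -> "lora_1"
print(model.get_adapter_state_dict().keys())   # only the adapter weight tensors

model.disable_adapters()                       # fall back to the base model
model.enable_adapters()                        # re-enable the injected adapter layers
```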
@@ -38,6 +38,7 @@ from .configuration_utils import PretrainedConfig
from .deepspeed import deepspeed_config, is_deepspeed_zero3_enabled
from .dynamic_module_utils import custom_object_save
from .generation import GenerationConfig, GenerationMixin
+from .lib_integrations import PeftAdapterMixin
from .pytorch_utils import (  # noqa: F401
    Conv1D,
    apply_chunking_to_forward,
@@ -48,6 +49,8 @@ from .pytorch_utils import (  # noqa: F401
    prune_linear_layer,
)
from .utils import (
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
    DUMMY_INPUTS,
    FLAX_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
@@ -68,6 +71,7 @@ from .utils import (
    is_bitsandbytes_available,
    is_offline_mode,
    is_optimum_available,
+    is_peft_available,
    is_remote_url,
    is_safetensors_available,
    is_torch_tpu_available,
@@ -123,6 +127,9 @@ if is_sagemaker_mp_enabled():
else:
    IS_SAGEMAKER_MP_POST_1_10 = False

+if is_peft_available():
+    from .utils import find_adapter_config_file
+

@contextmanager
def no_init_weights(_enable=True):
@@ -1039,7 +1046,7 @@ class ModuleUtilsMixin:
        return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)


-class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin):
+class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMixin, PeftAdapterMixin):
    r"""
    Base class for all models.

@@ -1738,6 +1745,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        safe_serialization: bool = False,
        variant: Optional[str] = None,
        token: Optional[Union[str, bool]] = None,
+        save_peft_format: bool = True,
        **kwargs,
    ):
        """
@@ -1780,6 +1788,10 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            token (`str` or `bool`, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
                the token generated when running `huggingface-cli login` (stored in `~/.huggingface`).
+            save_peft_format (`bool`, *optional*, defaults to `True`):
+                For backward compatibility with the PEFT library, in case adapter weights are attached to the model,
+                all keys of the adapter state dict need to be prepended with `base_model.model`. Advanced users can
+                disable this behaviour by setting `save_peft_format` to `False`.
            kwargs (`Dict[str, Any]`, *optional*):
                Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
@@ -1847,12 +1859,33 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if self._auto_class is not None:
            custom_object_save(self, save_directory, config=self.config)

+        _hf_peft_config_loaded = getattr(model_to_save, "_hf_peft_config_loaded", False)
+
        # Save the config
        if is_main_process:
-            model_to_save.config.save_pretrained(save_directory)
+            if not _hf_peft_config_loaded:
+                model_to_save.config.save_pretrained(save_directory)
            if self.can_generate():
                model_to_save.generation_config.save_pretrained(save_directory)

+        if _hf_peft_config_loaded:
+            logger.info(
+                "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
+            )
+            state_dict = model_to_save.get_adapter_state_dict()
+
+            if save_peft_format:
+                logger.info(
+                    "To match the expected format of the PEFT library, all keys of the state dict of adapters will be pre-pended with `base_model.model`."
+                )
+                peft_state_dict = {}
+                for key, value in state_dict.items():
+                    peft_state_dict[f"base_model.model.{key}"] = value
+                state_dict = peft_state_dict
+
+            current_peft_config = self.peft_config[self.active_adapter()]
+            current_peft_config.save_pretrained(save_directory)
+
        # Save the model
        if state_dict is None:
            state_dict = model_to_save.state_dict()
@@ -1907,8 +1940,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            )

        # Shard the model if it is too big.
-        weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
+        if not _hf_peft_config_loaded:
+            weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+            weights_name = _add_variant(weights_name, variant)
+        else:
+            weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

        shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)

@@ -2295,6 +2331,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        subfolder = kwargs.pop("subfolder", "")
        commit_hash = kwargs.pop("_commit_hash", None)
        variant = kwargs.pop("variant", None)
+        _adapter_model_path = kwargs.pop("_adapter_model_path", None)
+        adapter_name = kwargs.pop("adapter_name", "default")

        if is_fsdp_enabled():
            low_cpu_mem_usage = True
@@ -2323,6 +2361,29 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                " ignored."
            )

+        if is_peft_available() and _adapter_model_path is None:
+            maybe_adapter_model_path = find_adapter_config_file(
+                pretrained_model_name_or_path,
+                revision=revision,
+                subfolder=subfolder,
+                token=token,
+                commit_hash=commit_hash,
+            )
+        elif is_peft_available() and _adapter_model_path is not None:
+            maybe_adapter_model_path = _adapter_model_path
+        else:
+            maybe_adapter_model_path = None
+
+        has_adapter_config = maybe_adapter_model_path is not None
+
+        if has_adapter_config:
+            if _adapter_model_path is not None:
+                adapter_model_id = _adapter_model_path
+            else:
+                with open(maybe_adapter_model_path, "r", encoding="utf-8") as f:
+                    adapter_model_id = pretrained_model_name_or_path
+                    pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
+
        # change device_map into a map if we passed an int, a str or a torch.device
        if isinstance(device_map, torch.device):
            device_map = {"": device_map}
@@ -3153,6 +3214,14 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            if quantization_method_from_config == QuantizationMethod.GPTQ:
                model = quantizer.post_init_model(model)

+        if has_adapter_config:
+            model.load_adapter(
+                adapter_model_id,
+                adapter_name=adapter_name,
+                revision=revision,
+                token=token,
+            )
+
        if output_loading_info:
            if loading_info is None:
                loading_info = {
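A small sketch of what the new `save_pretrained` behaviour means in practice when adapters are attached (the directory names are illustrative):

```py
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
model.load_adapter("ybelkada/opt-350m-lora")

# Default: only the adapter weights are saved, with keys prefixed by `base_model.model.`
# so the checkpoint can be reloaded by the PEFT library directly.
model.save_pretrained("opt-350m-lora-adapter")

# Advanced users can keep the raw transformers-style keys instead.
model.save_pretrained("opt-350m-lora-adapter-raw", save_peft_format=False)
```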
@@ -15,13 +15,14 @@
"""Factory function to build auto-model classes."""
import copy
import importlib
+import json
import os
import warnings
from collections import OrderedDict

from ...configuration_utils import PretrainedConfig
from ...dynamic_module_utils import get_class_from_dynamic_module, resolve_trust_remote_code
-from ...utils import copy_func, logging, requires_backends
+from ...utils import copy_func, find_adapter_config_file, is_peft_available, logging, requires_backends
from .configuration_auto import AutoConfig, model_type_to_module_name, replace_list_option_in_docstrings

@@ -469,6 +470,24 @@ class _BaseAutoModelClass:
        if token is not None:
            hub_kwargs["token"] = token

+        if is_peft_available():
+            revision = kwargs.get("revision", None)
+            subfolder = kwargs.get("subfolder", None)
+
+            maybe_adapter_path = find_adapter_config_file(
+                pretrained_model_name_or_path,
+                revision=revision,
+                token=use_auth_token,
+                subfolder=subfolder,
+            )
+
+            if maybe_adapter_path is not None:
+                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
+                    adapter_config = json.load(f)
+
+                    kwargs["_adapter_model_path"] = pretrained_model_name_or_path
+                    pretrained_model_name_or_path = adapter_config["base_model_name_or_path"]
+
        if not isinstance(config, PretrainedConfig):
            kwargs_orig = copy.deepcopy(kwargs)
            # ensure not to pollute the config object with torch_dtype="auto" - since it's
@@ -34,8 +34,10 @@ from ..models.auto.tokenization_auto import TOKENIZER_MAPPING, AutoTokenizer
from ..tokenization_utils import PreTrainedTokenizer
from ..utils import (
    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    find_adapter_config_file,
    is_kenlm_available,
    is_offline_mode,
+    is_peft_available,
    is_pyctcdecode_available,
    is_tf_available,
    is_torch_available,
@@ -721,6 +723,21 @@ def pipeline(
        config = AutoConfig.from_pretrained(config, _from_pipeline=task, **hub_kwargs, **model_kwargs)
        hub_kwargs["_commit_hash"] = config._commit_hash
    elif config is None and isinstance(model, str):
+        # Check for an adapter file in the model path if PEFT is available
+        if is_peft_available():
+            subfolder = hub_kwargs.get("subfolder", None)
+            maybe_adapter_path = find_adapter_config_file(
+                model,
+                revision=revision,
+                token=use_auth_token,
+                subfolder=subfolder,
+            )
+
+            if maybe_adapter_path is not None:
+                with open(maybe_adapter_path, "r", encoding="utf-8") as f:
+                    adapter_config = json.load(f)
+                    model = adapter_config["base_model_name_or_path"]
+
        config = AutoConfig.from_pretrained(model, _from_pipeline=task, **hub_kwargs, **model_kwargs)
        hub_kwargs["_commit_hash"] = config._commit_hash
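With this change, an adapter repository id can be passed straight to `pipeline`, which resolves the base model from `adapter_config.json`; a minimal sketch using the same adapter checkpoint as the documentation:

```py
from transformers import pipeline

# "ybelkada/opt-350m-lora" is resolved to its base model ("facebook/opt-350m")
# and the adapter weights are loaded on top of it.
pipe = pipeline("text-generation", model="ybelkada/opt-350m-lora")
print(pipe("Hello", max_new_tokens=10)[0]["generated_text"])
```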
@@ -69,6 +69,7 @@ from .utils import (
    is_onnx_available,
    is_optimum_available,
    is_pandas_available,
+    is_peft_available,
    is_phonemizer_available,
    is_pyctcdecode_available,
    is_pytesseract_available,
@@ -369,6 +370,16 @@ def require_torch(test_case):
    return unittest.skipUnless(is_torch_available(), "test requires PyTorch")(test_case)


+def require_peft(test_case):
+    """
+    Decorator marking a test that requires PEFT.
+
+    These tests are skipped when PEFT isn't installed.
+
+    """
+    return unittest.skipUnless(is_peft_available(), "test requires PEFT")(test_case)
+
+
def require_torchvision(test_case):
    """
    Decorator marking a test that requires Torchvision.
|
||||
)
|
||||
|
||||
# At this stage the model is already loaded
|
||||
if getattr(model, "is_quantized", False):
|
||||
if getattr(model, "is_quantized", False) and not getattr(model, "_hf_peft_config_loaded", False):
|
||||
if getattr(model, "_is_quantized_training_enabled", False):
|
||||
logger.info(
|
||||
"The model is quantized. To train this model you need to add additional modules"
|
||||
|
@@ -179,13 +179,17 @@ from .import_utils import (
    requires_backends,
    torch_only_method,
)
+from .peft_utils import (
+    ADAPTER_CONFIG_NAME,
+    ADAPTER_SAFE_WEIGHTS_NAME,
+    ADAPTER_WEIGHTS_NAME,
+    check_peft_version,
+    find_adapter_config_file,
+)


WEIGHTS_NAME = "pytorch_model.bin"
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
-ADAPTER_CONFIG_NAME = "adapter_config.json"
-ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
-ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"
TF2_WEIGHTS_NAME = "tf_model.h5"
TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
TF_WEIGHTS_NAME = "model.ckpt"
|
||||
jieba`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
PEFT_IMPORT_ERROR = """
|
||||
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install
|
||||
peft`. Please note that you may need to restart your runtime after installation.
|
||||
"""
|
||||
|
||||
BACKENDS_MAPPING = OrderedDict(
|
||||
[
|
||||
("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
|
||||
@ -1034,6 +1039,7 @@ BACKENDS_MAPPING = OrderedDict(
|
||||
("decord", (is_decord_available, DECORD_IMPORT_ERROR)),
|
||||
("cython", (is_cython_available, CYTHON_IMPORT_ERROR)),
|
||||
("jieba", (is_jieba_available, JIEBA_IMPORT_ERROR)),
|
||||
("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
|
||||
]
|
||||
)
|
||||
|
||||
|
98  src/transformers/utils/peft_utils.py  Normal file
@@ -0,0 +1,98 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from typing import Optional

from packaging import version

from .hub import cached_file
from .import_utils import is_peft_available


ADAPTER_CONFIG_NAME = "adapter_config.json"
ADAPTER_WEIGHTS_NAME = "adapter_model.bin"
ADAPTER_SAFE_WEIGHTS_NAME = "adapter_model.safetensors"


def find_adapter_config_file(
    model_id: str,
    revision: str = None,
    subfolder: str = None,
    token: Optional[str] = None,
    commit_hash: Optional[str] = None,
) -> Optional[str]:
    r"""
    Simply checks if the model stored on the Hub or locally is an adapter model or not, and returns the path to the
    adapter config file if it is, `None` otherwise.

    Args:
        model_id (`str`):
            The identifier of the model to look for, can be either a local path or an id to the repository on the Hub.
        revision (`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.

            <Tip>

            To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.

            </Tip>

        subfolder (`str`, *optional*, defaults to `""`):
            In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
            specify the folder name here.
        token (`str`, *optional*):
            The authentication token to use to load the remote folder. Useful to load private repositories that are
            on the HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to cache it.
    """
    adapter_cached_filename = None
    if os.path.isdir(model_id):
        list_remote_files = os.listdir(model_id)
        if ADAPTER_CONFIG_NAME in list_remote_files:
            adapter_cached_filename = os.path.join(model_id, ADAPTER_CONFIG_NAME)
    else:
        adapter_cached_filename = cached_file(
            model_id,
            ADAPTER_CONFIG_NAME,
            revision=revision,
            token=token,
            _commit_hash=commit_hash,
            subfolder=subfolder,
            _raise_exceptions_for_missing_entries=False,
            _raise_exceptions_for_connection_errors=False,
        )

    return adapter_cached_filename


def check_peft_version(min_version: str) -> None:
    r"""
    Checks if the version of PEFT is compatible.

    Args:
        min_version (`str`):
            The minimum version of PEFT to check against.
    """
    if not is_peft_available():
        raise ValueError("PEFT is not installed. Please install it with `pip install peft`")

    # The installed PEFT version must be at least `min_version` to be compatible.
    is_peft_version_compatible = version.parse(importlib.metadata.version("peft")) >= version.parse(min_version)

    if not is_peft_version_compatible:
        raise ValueError(
            f"The version of PEFT you are using is not compatible, please use a version that is greater"
            f" than {min_version}"
        )
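A quick sketch of how these helpers behave (the repository ids are the ones used elsewhere in this PR; for a local path, the function simply looks for `adapter_config.json` in the directory):

```py
from transformers.utils import check_peft_version, find_adapter_config_file

# Raises a ValueError if `peft` is missing or older than the required version.
check_peft_version(min_version="0.4.0")

# Returns the cached path to `adapter_config.json` for an adapter repo, or None otherwise.
print(find_adapter_config_file("ybelkada/opt-350m-lora"))
print(find_adapter_config_file("facebook/opt-350m"))  # -> None
```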
236  tests/peft_integration/test_peft_integration.py  Normal file
@@ -0,0 +1,236 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import tempfile
import unittest

from transformers import AutoModelForCausalLM, OPTForCausalLM
from transformers.testing_utils import require_peft, require_torch, require_torch_gpu, slow, torch_device
from transformers.utils import is_torch_available


if is_torch_available():
    import torch


@require_peft
@require_torch
class PeftTesterMixin:
    peft_test_model_ids = ("peft-internal-testing/tiny-OPTForCausalLM-lora",)
    transformers_test_model_ids = ("hf-internal-testing/tiny-random-OPTForCausalLM",)
    transformers_test_model_classes = (AutoModelForCausalLM, OPTForCausalLM)


# TODO: run it with CI after PEFT release.
@slow
class PeftIntegrationTester(unittest.TestCase, PeftTesterMixin):
    """
    A testing suite that makes sure that the PeftModel class is correctly integrated into the transformers library.
    """

    def _check_lora_correctly_converted(self, model):
        """
        Utility method to check if the model has correctly adapters injected on it.
        """
        from peft.tuners.tuners_utils import BaseTunerLayer

        is_peft_loaded = False

        for _, m in model.named_modules():
            if isinstance(m, BaseTunerLayer):
                is_peft_loaded = True
                break

        return is_peft_loaded

    def test_peft_from_pretrained(self):
        """
        Simple test that tests the basic usage of PEFT model through `from_pretrained`.
        This checks if we pass a remote folder that contains an adapter config and adapter weights, it
        should correctly load a model that has adapters injected on it.
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                self.assertTrue(self._check_lora_correctly_converted(peft_model))
                self.assertTrue(peft_model._hf_peft_config_loaded)
                # dummy generation
                _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))

    def test_peft_state_dict(self):
        """
        Simple test that checks if the returned state dict of `get_adapter_state_dict()` method contains
        the expected keys.
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                state_dict = peft_model.get_adapter_state_dict()

                for key in state_dict.keys():
                    self.assertTrue("lora" in key)

    def test_peft_save_pretrained(self):
        """
        Test that checks various combinations of `save_pretrained` with a model that has adapters loaded
        on it. This checks if the saved model contains the expected files (adapter weights and adapter config).
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                with tempfile.TemporaryDirectory() as tmpdirname:
                    peft_model.save_pretrained(tmpdirname)

                    self.assertTrue("adapter_model.bin" in os.listdir(tmpdirname))
                    self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))

                    self.assertTrue("config.json" not in os.listdir(tmpdirname))
                    self.assertTrue("pytorch_model.bin" not in os.listdir(tmpdirname))

                    peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
                    self.assertTrue(self._check_lora_correctly_converted(peft_model))

                    peft_model.save_pretrained(tmpdirname, safe_serialization=True)
                    self.assertTrue("adapter_model.safetensors" in os.listdir(tmpdirname))
                    self.assertTrue("adapter_config.json" in os.listdir(tmpdirname))

                    peft_model = transformers_class.from_pretrained(tmpdirname).to(torch_device)
                    self.assertTrue(self._check_lora_correctly_converted(peft_model))

    def test_peft_enable_disable_adapters(self):
        """
        A test that checks if `enable_adapters` and `disable_adapters` methods work as expected.
        """
        from peft import LoraConfig

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)

                peft_model.add_adapter(peft_config)

                peft_logits = peft_model(dummy_input).logits

                peft_model.disable_adapters()

                peft_logits_disabled = peft_model(dummy_input).logits

                peft_model.enable_adapters()

                peft_logits_enabled = peft_model(dummy_input).logits

                self.assertTrue(torch.allclose(peft_logits, peft_logits_enabled, atol=1e-12, rtol=1e-12))
                self.assertFalse(torch.allclose(peft_logits_enabled, peft_logits_disabled, atol=1e-12, rtol=1e-12))

    def test_peft_add_adapter(self):
        """
        Simple test that tests if `add_adapter` works as expected
        """
        from peft import LoraConfig

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                peft_config = LoraConfig(init_lora_weights=False)

                model.add_adapter(peft_config)

                self.assertTrue(self._check_lora_correctly_converted(model))
                # dummy generation
                _ = model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))

    def test_peft_add_multi_adapter(self):
        """
        Simple test that tests the basic usage of PEFT model through `from_pretrained`. This test tests if
        add_adapter works as expected in multi-adapter setting.
        """
        from peft import LoraConfig
        from peft.tuners.tuners_utils import BaseTunerLayer

        dummy_input = torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device)

        for model_id in self.transformers_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                is_peft_loaded = False
                model = transformers_class.from_pretrained(model_id).to(torch_device)

                logits_original_model = model(dummy_input).logits

                peft_config = LoraConfig(init_lora_weights=False)

                model.add_adapter(peft_config)

                logits_adapter_1 = model(dummy_input)

                model.add_adapter(peft_config, adapter_name="adapter-2")

                logits_adapter_2 = model(dummy_input)

                for _, m in model.named_modules():
                    if isinstance(m, BaseTunerLayer):
                        is_peft_loaded = True
                        break

                self.assertTrue(is_peft_loaded)

                # dummy generation
                _ = model.generate(input_ids=dummy_input)

                model.set_adapter("default")
                self.assertTrue(model.active_adapter() == "default")

                model.set_adapter("adapter-2")
                self.assertTrue(model.active_adapter() == "adapter-2")

                # Logits comparison
                self.assertFalse(
                    torch.allclose(logits_adapter_1.logits, logits_adapter_2.logits, atol=1e-6, rtol=1e-6)
                )
                self.assertFalse(torch.allclose(logits_original_model, logits_adapter_2.logits, atol=1e-6, rtol=1e-6))

    @require_torch_gpu
    def test_peft_from_pretrained_kwargs(self):
        """
        Simple test that tests the basic usage of PEFT model through `from_pretrained` + additional kwargs
        and see if the integration behaves as expected.
        """
        for model_id in self.peft_test_model_ids:
            for transformers_class in self.transformers_test_model_classes:
                peft_model = transformers_class.from_pretrained(model_id, load_in_8bit=True, device_map="auto")

                module = peft_model.model.decoder.layers[0].self_attn.v_proj
                self.assertTrue(module.__class__.__name__ == "Linear8bitLt")
                self.assertTrue(peft_model.hf_device_map is not None)

                # dummy generation
                _ = peft_model.generate(input_ids=torch.LongTensor([[0, 1, 2, 3, 4, 5, 6, 7]]).to(torch_device))

    def test_peft_pipeline(self):
        """
        Simple test that tests the basic usage of PEFT model + pipeline
        """
        from transformers import pipeline

        for model_id in self.peft_test_model_ids:
            pipe = pipeline("text-generation", model_id)
            _ = pipe("Hello")