mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
[core / Quantization] AWQ integration (#27045)
* working v1 * oops * Update src/transformers/modeling_utils.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * fixup * oops * push * more changes * add docs * some fixes * fix copies * add v1 doc * added installation guide * relax constraints * revert * attempt llm-awq * oops * oops * fixup * raise error when incorrect cuda compute capability * nit * add instructions for llm-awq * fixup * fix copies * fixup and docs * change * few changes + add demo * add v1 tests * add autoawq in dockerfile * finalize * Update tests/quantization/autoawq/test_awq.py * fix test * fix * fix issue * Update src/transformers/integrations/awq.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/main_classes/quantization.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update docs/source/en/main_classes/quantization.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/integrations/awq.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * Update src/transformers/integrations/awq.py Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add link to example script * Update docs/source/en/main_classes/quantization.md Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * add more content * add more details * add link to quantization docs * camel case + change backend class name * change to string * fixup * raise errors if libs not installed * change to `bits` and `group_size` * nit * nit * Apply suggestions from code review Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * disable training * address some comments and fix nits * fix * final nits and fix tests * adapt to our new runners * make fix-copies * Update src/transformers/utils/quantization_config.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/utils/quantization_config.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/integrations/awq.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * Update src/transformers/integrations/awq.py Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * move to top * add conversion test * final nit * add more elaborated test --------- Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
82c7e87987
commit
ae093eef01
@@ -55,6 +55,9 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu

# Add einops for additional model testing
RUN python3 -m pip install --no-cache-dir einops

# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir autoawq

# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum

@@ -16,6 +16,97 @@ rendered properly in your Markdown viewer.

# Quantize 🤗 Transformers models

## AWQ integration

The AWQ method was introduced in the paper [*AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration*](https://arxiv.org/abs/2306.00978). With AWQ you can run models in 4-bit precision while preserving their original quality (i.e. no performance degradation), with higher throughput than the other quantization methods presented below, reaching speeds similar to pure `float16` inference.

We now support inference with any AWQ model, meaning anyone can load and use AWQ weights that are pushed on the Hub or saved locally. Note that using AWQ requires access to an NVIDIA GPU. CPU inference is not supported yet.

### Quantizing a model

We advise users to look at the different existing tools in the ecosystem to quantize their models with the AWQ algorithm, such as:

- [`llm-awq`](https://github.com/mit-han-lab/llm-awq) from MIT Han Lab
- [`autoawq`](https://github.com/casper-hansen/AutoAWQ) from [`casper-hansen`](https://github.com/casper-hansen)
- Intel Neural Compressor from Intel, through [`optimum-intel`](https://huggingface.co/docs/optimum/main/en/intel/optimization_inc)

Many other tools might exist in the ecosystem; please feel free to open a PR to add them to the list.

Currently the integration with 🤗 Transformers is only available for models that have been quantized using the `autoawq` library or `llm-awq`. Most of the models quantized with `autoawq` can be found under the [`TheBloke`](https://huggingface.co/TheBloke) namespace on the 🤗 Hub. To quantize models with `llm-awq`, please refer to the [`convert_to_hf.py`](https://github.com/mit-han-lab/llm-awq/blob/main/examples/convert_to_hf.py) script in the examples folder of [`llm-awq`](https://github.com/mit-han-lab/llm-awq/).
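The tools above each have their own APIs. As a rough illustration, quantizing a model with `autoawq` typically looks like the sketch below; the model path, quantization settings and output directory are placeholders, and the exact API may differ between `autoawq` versions, so refer to its README for the authoritative usage.

```python
# Hedged sketch of quantizing a model with the `autoawq` library (API may vary by version).
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"  # placeholder: any supported fp16 checkpoint
quant_path = "mistral-7b-awq"             # placeholder: where to save the packed weights
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the float16 model and its tokenizer
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Run AWQ calibration and pack the weights into 4-bit
model.quantize(tokenizer, quant_config=quant_config)

# Save the quantized checkpoint next to its tokenizer
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```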
### Load a quantized model

You can load a quantized model from the Hub using the `from_pretrained` method. Make sure that the pushed weights are quantized by checking that the attribute `quantization_config` is present in the model's configuration file (`config.json`). You can confirm that the model is quantized in the AWQ format by checking the field `quantization_config.quant_method`, which should be set to `"awq"`. Note that loading the model will set the other weights to `float16` by default for performance reasons. If you want to change that behavior, you can set the `torch_dtype` argument to `torch.float32` or `torch.bfloat16`. You can find some example snippets and a notebook in the sections below.
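For example, a quick (illustrative) way to inspect the quantization metadata and override the default dtype, reusing the checkpoint from the snippets below:

```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "TheBloke/zephyr-7B-alpha-AWQ"

# `quantization_config` is read from the checkpoint's config.json as a plain dict
config = AutoConfig.from_pretrained(model_id)
print(config.quantization_config["quant_method"])  # expected: "awq"

# Override the default float16 dtype used for the non-quantized weights
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="cuda:0")
```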
## Example usage

First, you need to install the [`autoawq`](https://github.com/casper-hansen/AutoAWQ) library:

```bash
pip install autoawq
```

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/zephyr-7B-alpha-AWQ"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")
```
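Once loaded, the model can be used for generation like any other Transformers model; a minimal illustrative example (the prompt is arbitrary):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/zephyr-7B-alpha-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")

# Tokenize a prompt, generate on the GPU and decode the result
inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda:0")
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```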
If you first load your model on CPU, make sure to move it to your GPU device before using it:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/zephyr-7B-alpha-AWQ"
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0")
```

### Combining AWQ and Flash Attention

You can combine AWQ quantization with Flash Attention to get a model that is both quantized and faster. Simply load the model using `from_pretrained` and pass the `use_flash_attention_2=True` argument.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
```

### Benchmarks

We performed some speed, throughput and latency benchmarks using the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.

Note that at the time of writing this documentation section, the available quantization methods were: `awq`, `gptq` and `bitsandbytes`.

The benchmark was run on an NVIDIA A100 instance, with [`TheBloke/Mistral-7B-v0.1-AWQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-AWQ) as the AWQ model and [`TheBloke/Mistral-7B-v0.1-GPTQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GPTQ) as the GPTQ model. We also benchmarked against `bitsandbytes` quantization and a native `float16` model. Some results are shown below:

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_memory_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_throughput_plot.png">
</div>

<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_latency_plot.png">
</div>

You can find the full results, together with package versions, at [this link](https://github.com/huggingface/optimum-benchmark/tree/main/examples/running-mistral).

From the results, it appears that AWQ is the fastest quantization method for inference and text generation, and has among the lowest peak memory usage for text generation. However, AWQ seems to have the largest forward latency per batch size.

### Google Colab demo

Check out how to use this integration in this [Google Colab demo](https://colab.research.google.com/drive/1HzZH89yAXJaZgwJDhQj9LqSBux932BvY)!

### AwqConfig

[[autodoc]] AwqConfig

## `AutoGPTQ` Integration

🤗 Transformers has integrated the `optimum` API to perform GPTQ quantization on language models. You can load and quantize your model in 8, 4, 3 or even 2 bits without a big drop in performance, and with faster inference speed! This is supported by most GPU hardware.
src/transformers/__init__.py

@@ -778,7 +778,7 @@ _import_structure = {
        "is_vision_available",
        "logging",
    ],
    "utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
    "utils.quantization_config": ["AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
}

# sentencepiece-backed objects

@@ -4943,7 +4943,7 @@ if TYPE_CHECKING:
    )

    # bitsandbytes config
    from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig
    from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig

    try:
        if not is_sentencepiece_available():
src/transformers/integrations/__init__.py

@@ -17,6 +17,7 @@ from ..utils import _LazyModule


_import_structure = {
    "awq": ["replace_with_awq_linear"],
    "bitsandbytes": [
        "get_keys_to_not_convert",
        "replace_8bit_linear",
@@ -77,6 +78,7 @@ _import_structure = {
}

if TYPE_CHECKING:
    from .awq import replace_with_awq_linear
    from .bitsandbytes import (
        get_keys_to_not_convert,
        replace_8bit_linear,
src/transformers/integrations/awq.py (new file, 104 lines)

@@ -0,0 +1,104 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AWQ (Activation aware Weight Quantization) integration file"
from ..utils import is_auto_awq_available, is_torch_available
from ..utils.quantization_config import AwqBackendPackingMethod, AWQLinearVersion


if is_torch_available():
    import torch.nn as nn


def replace_with_awq_linear(
    model,
    modules_to_not_convert=None,
    quantization_config=None,
    current_key_name=None,
    has_been_replaced=False,
) -> bool:
    """
    Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    During the module replacement, we also infer the backend to use through the `quantization_config` object.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`AwqConfig`):
            The quantization config object that contains the quantization parameters.
        modules_to_not_convert (`list`, *optional*):
            A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
            converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if modules_to_not_convert is None:
        modules_to_not_convert = []

    backend = quantization_config.backend

    if not is_auto_awq_available():
        raise ValueError(
            "AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
        )

    if backend == AwqBackendPackingMethod.AUTOAWQ:
        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
    elif backend == AwqBackendPackingMethod.LLMAWQ:
        from awq.quantize.qmodule import WQLinear

    if backend == AwqBackendPackingMethod.AUTOAWQ:
        target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
    else:
        target_cls = WQLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
            # Check if the current key is not in the `modules_to_not_convert`
            if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
                in_features = module.in_features
                out_features = module.out_features

                model._modules[name] = target_cls(
                    w_bit=quantization_config.bits,
                    group_size=quantization_config.group_size,
                    in_features=in_features,
                    out_features=out_features,
                    bias=module.bias is not None,
                    dev=module.weight.device,
                )
                has_been_replaced = True

                # Force requires grad to False to avoid unexpected errors
                model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_awq_linear(
                module,
                modules_to_not_convert=modules_to_not_convert,
                current_key_name=current_key_name,
                quantization_config=quantization_config,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
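For reference, the conversion utility above is driven roughly as in the sketch below, which mirrors the accompanying test; `facebook/opt-350m` is just an example architecture, and both `autoawq` and a CUDA device are assumed to be available (the `AwqConfig` constructor checks for one).

```python
# Illustrative sketch: swap nn.Linear modules for AWQ packed linears on an empty (meta) model.
from accelerate import init_empty_weights
from transformers import AutoConfig, AwqConfig, OPTForCausalLM
from transformers.integrations.awq import replace_with_awq_linear

config = AutoConfig.from_pretrained("facebook/opt-350m")
quantization_config = AwqConfig(bits=4)  # requires a CUDA device

# No real weights are allocated here, so only the module classes are swapped
with init_empty_weights():
    model = OPTForCausalLM(config)

model, has_been_replaced = replace_with_awq_linear(
    model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
)
print(has_been_replaced)  # True if at least one linear layer was converted
```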
src/transformers/modeling_utils.py

@@ -70,6 +70,7 @@ from .utils import (
    extract_commit_hash,
    has_file,
    is_accelerate_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_flash_attn_2_available,
@@ -90,7 +91,7 @@ from .utils.import_utils import (
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
)
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.versions import require_version_core


@@ -2674,6 +2675,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                quantization_config, "quant_method", QuantizationMethod.BITS_AND_BYTES
            )

            if quantization_method_from_args == QuantizationMethod.AWQ:
                raise ValueError(
                    "You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models"
                    " for inference. To quantize transformers models with AWQ algorithm, please refer to our"
                    " quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization "
                )

        if quantization_config is None and (load_in_8bit or load_in_4bit):
            quantization_method_from_args = QuantizationMethod.BITS_AND_BYTES
            quantization_config, kwargs = BitsAndBytesConfig.from_dict(
@@ -2805,6 +2813,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
                logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")

            quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
        elif quantization_method_from_config == QuantizationMethod.AWQ:
            if not torch.cuda.is_available():
                raise RuntimeError("GPU is required to run AWQ quantized model.")

            if not is_auto_awq_available():
                raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")

            if not is_accelerate_available():
                raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")

            if device_map is None:
                logger.warning(
                    "You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
                    "your model on a GPU device in order to run your model."
                )
            elif device_map is not None:
                if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
                    raise ValueError(
                        "You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
                        " This is not supported. Please remove the CPU or disk device from the device_map."
                    )

            if torch_dtype is None:
                torch_dtype = torch.float16
            else:
                logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")

            # Force-set to `True` for more mem efficiency
            if low_cpu_mem_usage is None:
                low_cpu_mem_usage = True

        if (
            is_8bit_serializable
@@ -3265,6 +3303,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
        if quantization_method_from_config == QuantizationMethod.GPTQ:
            model = quantizer.convert_model(model)
            model._is_quantized_training_enabled = True
        elif quantization_method_from_config == QuantizationMethod.AWQ:
            from .integrations import get_keys_to_not_convert, replace_with_awq_linear

            modules_to_not_convert = get_keys_to_not_convert(model)

            if quantization_config is None:
                quantization_config = AwqConfig.from_dict(config.quantization_config)

            model, has_been_replaced = replace_with_awq_linear(
                model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
            )
            model._is_quantized_training_enabled = False

            if not has_been_replaced:
                logger.warning(
                    "You are loading an AWQ model but no linear modules were found in your model."
                    " Please double check your model architecture, or submit an issue on github if you think this is"
                    " a bug."
                )

        if quantization_method_from_config is not None:
            model.quantization_method = quantization_method_from_config
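In practice, the checks added above mean that an AWQ checkpoint has to be loaded entirely onto GPU(s), with `torch_dtype` defaulting to `float16`. A small illustrative sketch (the model id is reused from the docs above):

```python
# Illustrative only: these calls reflect the constraints enforced by the branches added above.
from transformers import AutoModelForCausalLM

model_id = "TheBloke/zephyr-7B-alpha-AWQ"

# OK: everything on the GPU; torch_dtype defaults to torch.float16 when left unset
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")

# Raises ValueError: AWQ weights cannot be offloaded to CPU or disk
# model = AutoModelForCausalLM.from_pretrained(model_id, device_map={"": "cpu"})
```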
src/transformers/testing_utils.py

@@ -52,6 +52,7 @@ from .integrations.deepspeed import is_deepspeed_available
from .utils import (
    is_accelerate_available,
    is_apex_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_bs4_available,
@@ -963,6 +964,13 @@ def require_auto_gptq(test_case):
    return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)


def require_auto_awq(test_case):
    """
    Decorator for auto_awq dependency
    """
    return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)


def require_phonemizer(test_case):
    """
    Decorator marking a test that requires phonemizer
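A quick illustration of how the new decorator is meant to be applied (the test class and body here are placeholders):

```python
import unittest

from transformers.testing_utils import require_auto_awq, require_torch_gpu


@require_torch_gpu
@require_auto_awq
class MyAwqTests(unittest.TestCase):
    def test_placeholder(self):
        # Runs only when both a GPU and the autoawq package are available
        self.assertTrue(True)
```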
src/transformers/utils/__init__.py

@@ -104,6 +104,7 @@ from .import_utils import (
    get_torch_version,
    is_accelerate_available,
    is_apex_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
    is_bs4_available,
src/transformers/utils/import_utils.py

@@ -107,6 +107,8 @@ _onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@@ -675,6 +677,10 @@ def is_optimum_available():
    return _optimum_available


def is_auto_awq_available():
    return _auto_awq_available


def is_auto_gptq_available():
    return _auto_gptq_available

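As a small sketch, downstream code is expected to guard `awq` imports behind this helper, as the integration file above does:

```python
from transformers.utils import is_auto_awq_available

if is_auto_awq_available():
    # Only import the packed linear kernels when the `autoawq` package is installed
    from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
```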
src/transformers/utils/quantization_config.py

@@ -37,6 +37,17 @@ logger = logging.get_logger(__name__)
class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"


class AWQLinearVersion(str, Enum):
    GEMM = "gemm"
    GEMV = "gemv"


class AwqBackendPackingMethod(str, Enum):
    AUTOAWQ = "autoawq"
    LLMAWQ = "llm-awq"


@dataclass
@@ -418,3 +429,67 @@ class GPTQConfig(QuantizationConfigMixin):
                f"""dataset needs to be either a list of string or a value in
                ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
            )


@dataclass
class AwqConfig(QuantizationConfigMixin):
    """
    This is a wrapper class for all the possible attributes and features that you can play with for a model that has
    been quantized with the `auto-awq` library, relying on the auto_awq backend.

    Args:
        bits (`int`, *optional*, defaults to 4):
            The number of bits to quantize to.
        group_size (`int`, *optional*, defaults to 128):
            The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
        zero_point (`bool`, *optional*, defaults to `True`):
            Whether to use zero point quantization.
        version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
            The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8);
            otherwise, GEMV is better (e.g. < 8).
        backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
            The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
            that quantize their own models using the `llm-awq` library.
    """

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        zero_point: bool = True,
        version: AWQLinearVersion = AWQLinearVersion.GEMM,
        backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.AWQ

        self.bits = bits
        self.group_size = group_size
        self.zero_point = zero_point
        self.version = version
        self.backend = backend

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct
        """
        if not torch.cuda.is_available():
            raise ValueError("AWQ is only available on GPU")

        if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
            raise ValueError(
                f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
            )

        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
            raise ValueError(
                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
            )

        if self.backend == AwqBackendPackingMethod.LLMAWQ:
            compute_capability = torch.cuda.get_device_capability()
            major, minor = compute_capability
            if major < 8:
                raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")
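To make the new config concrete, here is a small illustrative sketch of building an `AwqConfig` and round-tripping it through a dict; note that, per `post_init` above, constructing the config requires a CUDA device:

```python
from transformers import AwqConfig

# Requires a CUDA device, since AwqConfig.post_init checks torch.cuda.is_available()
quantization_config = AwqConfig(bits=4, group_size=128, zero_point=True)

# Serialize to a plain dict (roughly what ends up under `quantization_config` in config.json)
config_dict = quantization_config.to_dict()

# Rebuild a config object from a dict, as the loading path does with `AwqConfig.from_dict`
restored = AwqConfig.from_dict({"bits": 4, "zero_point": True, "backend": "autoawq"})
assert restored.bits == 4 and restored.backend == "autoawq"
```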
tests/quantization/autoawq/__init__.py (new file, empty)

tests/quantization/autoawq/test_awq.py (new file, 221 lines)

@@ -0,0 +1,221 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tempfile
import unittest

from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqConfig, OPTForCausalLM
from transformers.testing_utils import (
    require_accelerate,
    require_auto_awq,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights


@require_torch_gpu
class AwqConfigTest(unittest.TestCase):
    def test_wrong_backend(self):
        """
        Simple test that checks if a user passes a wrong backend an error is raised
        """
        # This should work fine
        _ = AwqConfig(bits=4)

        with self.assertRaises(ValueError):
            AwqConfig(bits=4, backend="")

        # LLMAWQ does not work on a T4
        with self.assertRaises(ValueError):
            AwqConfig(bits=4, backend="llm-awq")

    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = AwqConfig(bits=4)
        config_to_dict = quantization_config.to_dict()

        for key in config_to_dict:
            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        dict = {"bits": 2, "zero_point": False, "backend": "autoawq"}
        quantization_config = AwqConfig.from_dict(dict)

        self.assertEqual(dict["bits"], quantization_config.bits)
        self.assertEqual(dict["zero_point"], quantization_config.zero_point)
        self.assertEqual(dict["backend"], quantization_config.backend)


@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqTest(unittest.TestCase):
    # TODO: @younesbelkada change it to `TheBloke/Mistral-7B-v0.1-AWQ` in the future
    model_name = "ybelkada/test-mistral-7b-v0.1-awq"
    dummy_transformers_model_name = "bigscience/bloom-560m"

    input_text = "Hello my name is"

    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
    EXPECTED_OUTPUT_BF16 = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Exercise and Sport Science with a"

    device_map = "cuda"

    # called only once for all test in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name,
            device_map=cls.device_map,
        )

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """
        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

        from transformers.integrations.awq import replace_with_awq_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        quantization_config = AwqConfig(bits=4)

        with init_empty_weights():
            model = OPTForCausalLM(config)

        nb_linears = 0
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                nb_linears += 1

        model, _ = replace_with_awq_linear(model, quantization_config=quantization_config)
        nb_awq_linear = 0
        for module in model.modules():
            if isinstance(module, (WQLinear_GEMM, WQLinear_GEMV)):
                nb_awq_linear += 1

        self.assertEqual(nb_linears, nb_awq_linear)

        # Try with `modules_not_to_convert`
        with init_empty_weights():
            model = OPTForCausalLM(config)

        model, _ = replace_with_awq_linear(
            model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
        )
        nb_awq_linear = 0
        for module in model.modules():
            if isinstance(module, (WQLinear_GEMM, WQLinear_GEMV)):
                nb_awq_linear += 1

        self.assertEqual(nb_linears - 1, nb_awq_linear)

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_quantized_model_bf16(self):
        """
        Simple test that checks if the quantized model is working properly with bf16
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.bfloat16).to(
            torch_device
        )

        output = quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)

    def test_quantized_model_no_device_map(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name).to(torch_device)
        output = quantized_model.generate(**input_ids, max_new_tokens=40)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=40)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_raise_quantization(self):
        """
        Simple test that checks if one passes a quantization config to quantize a model, it raises an error
        """
        quantization_config = AwqConfig(bits=4)

        with self.assertRaises(ValueError) as context:
            _ = AutoModelForCausalLM.from_pretrained(
                self.dummy_transformers_model_name, quantization_config=quantization_config
            )

        self.assertEqual(
            str(context.exception),
            "You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models for inference. To quantize transformers models with AWQ algorithm, please refer to our quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization ",
        )

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")

        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1, 2, 3})

        output = quantized_model.generate(**input_ids, max_new_tokens=40)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)