[core / Quantization] AWQ integration (#27045)

* working v1

* oops

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* fixup

* oops

* push

* more changes

* add docs

* some fixes

* fix copies

* add v1 doc

* added installation guide

* relax constraints

* revert

* attempt llm-awq

* oops

* oops

* fixup

* raise error when incorrect cuda compute capability

* nit

* add instructions for llm-awq

* fixup

* fix copies

* fixup and docs

* change

* few changes + add demo

* add v1 tests

* add autoawq in dockerfile

* finalize

* Update tests/quantization/autoawq/test_awq.py

* fix test

* fix

* fix issue

* Update src/transformers/integrations/awq.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add link to example script

* Update docs/source/en/main_classes/quantization.md

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* add more content

* add more details

* add link to quantization docs

* camel case + change backend class name

* change to string

* fixup

* raise errors if libs not installed

* change to `bits` and `group_size`

* nit

* nit

* Apply suggestions from code review

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* disable training

* address some comments and fix nits

* fix

* final nits and fix tests

* adapt to our new runners

* make fix-copies

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/integrations/awq.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* move to top

* add conversion test

* final nit

* add more elaborated test

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Younes Belkada 2023-11-01 09:06:31 +01:00 committed by GitHub
parent 82c7e87987
commit ae093eef01
12 changed files with 571 additions and 3 deletions

View File

@@ -55,6 +55,9 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
# Add einops for additional model testing
RUN python3 -m pip install --no-cache-dir einops
# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir autoawq
# For bettertransformer + gptq
RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum

View File

@@ -16,6 +16,97 @@ rendered properly in your Markdown viewer.
# Quantize 🤗 Transformers models
## AWQ integration
The AWQ method was introduced in the paper [*AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration*](https://arxiv.org/abs/2306.00978). With AWQ you can run models in 4-bit precision while preserving their original quality (i.e. no performance degradation), with higher throughput than the other quantization methods presented below - reaching throughput similar to pure `float16` inference.
We now support inference with any AWQ model, meaning anyone can load and use AWQ weights that are pushed on the Hub or saved locally. Note that using AWQ requires access to an NVIDIA GPU. CPU inference is not supported yet.
### Quantizing a model
We advise users to look at the different existing tools in the ecosystem to quantize their models with the AWQ algorithm, such as:
- [`llm-awq`](https://github.com/mit-han-lab/llm-awq) from MIT Han Lab
- [`autoawq`](https://github.com/casper-hansen/AutoAWQ) from [`casper-hansen`](https://github.com/casper-hansen)
- Intel Neural Compressor - through [`optimum-intel`](https://huggingface.co/docs/optimum/main/en/intel/optimization_inc)
Many other tools might exist in the ecosystem; please feel free to open a PR to add them to the list.
Currently the integration with 🤗 Transformers is only available for models that have been quantized using the `autoawq` library or `llm-awq`. Most of the models quantized with `autoawq` can be found under the [`TheBloke`](https://huggingface.co/TheBloke) namespace on the 🤗 Hub. To quantize models with `llm-awq`, please refer to the [`convert_to_hf.py`](https://github.com/mit-han-lab/llm-awq/blob/main/examples/convert_to_hf.py) script in the examples folder of [`llm-awq`](https://github.com/mit-han-lab/llm-awq/).
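For reference, quantizing a model with `autoawq` might look roughly like the following sketch (adapted from the `autoawq` README at the time of writing; `mistralai/Mistral-7B-v0.1` and the output folder name are only illustrative, and the exact API may vary across `autoawq` versions):
```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "mistralai/Mistral-7B-v0.1"   # model to quantize (illustrative)
quant_path = "mistral-7b-v0.1-awq"         # output folder (illustrative)
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the unquantized model and its tokenizer
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Run the AWQ calibration and packing step
model.quantize(tokenizer, quant_config=quant_config)

# Save the quantized checkpoint so it can later be loaded with `from_pretrained`
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```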
### Load a quantized model
You can load a quantized model from the Hub using the `from_pretrained` method. Make sure that the pushed weights are quantized by checking that the attribute `quantization_config` is present in the model's configuration file (`config.json`). You can confirm that the model is quantized in the AWQ format by checking the field `quantization_config.quant_method`, which should be set to `"awq"`. Note that loading the model will set the other (non-quantized) weights to `float16` by default for performance reasons. If you want to change that behavior, you can pass the `torch_dtype` argument set to `torch.float32` or `torch.bfloat16`. You can find example snippets and a notebook in the sections below.
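For example, you can inspect the quantization metadata and override the default dtype as follows (a minimal sketch; installation instructions for `autoawq` are given in the next section):
```python
import torch
from transformers import AutoConfig, AutoModelForCausalLM

model_id = "TheBloke/zephyr-7B-alpha-AWQ"

# The quantization metadata is stored in the checkpoint's config.json
config = AutoConfig.from_pretrained(model_id)
print(config.quantization_config)  # should contain "quant_method": "awq"

# Non-quantized weights default to float16; pass `torch_dtype` to change that
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="cuda:0"
)
```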
## Example usage
First, you need to install the [`autoawq`](https://github.com/casper-hansen/AutoAWQ) library:
```bash
pip install autoawq
```
Then you can load an AWQ quantized model from the Hub with `from_pretrained`:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "TheBloke/zephyr-7B-alpha-AWQ"
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")
```
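Once the model is loaded, generation works as usual. For example (a short sketch following the same pattern as the tests added in this PR):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TheBloke/zephyr-7B-alpha-AWQ"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda:0")

# Tokenize a prompt, generate and decode as with any other causal LM
inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda:0")
output = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```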
If you first load your model on CPU, make sure to move it to a GPU device before running inference:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_id = "TheBloke/zephyr-7B-alpha-AWQ"
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda:0")
```
### Combining AWQ and Flash Attention
You can combine AWQ quantization with Flash Attention to get a model that is both quantized and faster. Simply load the model using `from_pretrained` and pass the `use_flash_attention_2=True` argument:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("TheBloke/zephyr-7B-alpha-AWQ", use_flash_attention_2=True, device_map="cuda:0")
```
### Benchmarks
We performed some speed, throughput and latency benchmarks using the [`optimum-benchmark`](https://github.com/huggingface/optimum-benchmark) library.
Note that at the time of writing this documentation section, the available quantization methods were: `awq`, `gptq` and `bitsandbytes`.
The benchmark was run on an NVIDIA A100 instance, using [`TheBloke/Mistral-7B-v0.1-AWQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-AWQ) as the AWQ model and [`TheBloke/Mistral-7B-v0.1-GPTQ`](https://huggingface.co/TheBloke/Mistral-7B-v0.1-GPTQ) as the GPTQ model. We also benchmarked against the `bitsandbytes` quantization methods and a native `float16` model. Some results are shown below:
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_memory_plot.png">
</div>
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_memory_plot.png">
</div>
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/generate_throughput_plot.png">
</div>
<div style="text-align: center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/quantization/forward_latency_plot.png">
</div>
You can find the full results together with package versions at [this link](https://github.com/huggingface/optimum-benchmark/tree/main/examples/running-mistral).
From the results, AWQ is the fastest quantization method for inference and text generation, and it has among the lowest peak memory usage for text generation. However, AWQ seems to have the largest forward latency per batch size.
### Google Colab demo
Check out how to use this integration with this [Google Colab demo](https://colab.research.google.com/drive/1HzZH89yAXJaZgwJDhQj9LqSBux932BvY)!
### AwqConfig
[[autodoc]] AwqConfig
## `AutoGPTQ` Integration
🤗 Transformers has integrated the `optimum` API to perform GPTQ quantization on language models. You can load and quantize your model in 8, 4, 3 or even 2 bits without a big drop in performance and with faster inference speed! This is supported by most GPU hardware.

View File

@@ -778,7 +778,7 @@ _import_structure = {
"is_vision_available",
"logging",
],
"utils.quantization_config": ["BitsAndBytesConfig", "GPTQConfig"],
"utils.quantization_config": ["AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
}
# sentencepiece-backed objects
@@ -4943,7 +4943,7 @@ if TYPE_CHECKING:
)
# bitsandbytes config
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig
from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig
try:
if not is_sentencepiece_available():

View File

@@ -17,6 +17,7 @@ from ..utils import _LazyModule
_import_structure = {
"awq": ["replace_with_awq_linear"],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
@@ -77,6 +78,7 @@ _import_structure = {
}
if TYPE_CHECKING:
from .awq import replace_with_awq_linear
from .bitsandbytes import (
get_keys_to_not_convert,
replace_8bit_linear,

View File

@@ -0,0 +1,104 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AWQ (Activation aware Weight Quantization) integration file"
from ..utils import is_auto_awq_available, is_torch_available
from ..utils.quantization_config import AwqBackendPackingMethod, AWQLinearVersion
if is_torch_available():
import torch.nn as nn
def replace_with_awq_linear(
model,
modules_to_not_convert=None,
quantization_config=None,
current_key_name=None,
has_been_replaced=False,
) -> bool:
"""
Public method that recursively replaces the Linear layers of the given model with AWQ quantized layers.
`accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
conversion has been successful or not.
During the module replacement, we also infer the backend to use through the `quantization_config` object.
Args:
model (`torch.nn.Module`):
The model to convert, can be any `torch.nn.Module` instance.
quantization_config (`AwqConfig`):
The quantization config object that contains the quantization parameters.
modules_to_not_convert (`list`, *optional*):
A list of modules to not convert. If a module name is in the list (e.g. `lm_head`), it will not be
converted.
current_key_name (`list`, *optional*):
A list that contains the current key name. This is used for recursion and should not be passed by the user.
has_been_replaced (`bool`, *optional*):
A boolean that indicates if the conversion has been successful or not. This is used for recursion and
should not be passed by the user.
"""
if modules_to_not_convert is None:
modules_to_not_convert = []
backend = quantization_config.backend
if not is_auto_awq_available():
raise ValueError(
"AWQ (either `autoawq` or `llmawq`) is not available. Please install it with `pip install autoawq` or check out the installation guide in https://github.com/mit-han-lab/llm-awq"
)
if backend == AwqBackendPackingMethod.AUTOAWQ:
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
elif backend == AwqBackendPackingMethod.LLMAWQ:
from awq.quantize.qmodule import WQLinear
if backend == AwqBackendPackingMethod.AUTOAWQ:
target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
else:
target_cls = WQLinear
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
in_features = module.in_features
out_features = module.out_features
model._modules[name] = target_cls(
w_bit=quantization_config.bits,
group_size=quantization_config.group_size,
in_features=in_features,
out_features=out_features,
bias=module.bias is not None,
dev=module.weight.device,
)
has_been_replaced = True
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
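# Recurse into submodules so that nested linear layers are also replaced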
if len(list(module.children())) > 0:
_, has_been_replaced = replace_with_awq_linear(
module,
modules_to_not_convert=modules_to_not_convert,
current_key_name=current_key_name,
quantization_config=quantization_config,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced

View File

@@ -70,6 +70,7 @@ from .utils import (
extract_commit_hash,
has_file,
is_accelerate_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_flash_attn_2_available,
@@ -90,7 +91,7 @@ from .utils.import_utils import (
is_torch_fx_proxy,
is_torchdynamo_compiling,
)
from .utils.quantization_config import BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig, QuantizationMethod
from .utils.versions import require_version_core
@@ -2674,6 +2675,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
quantization_config, "quant_method", QuantizationMethod.BITS_AND_BYTES
)
if quantization_method_from_args == QuantizationMethod.AWQ:
raise ValueError(
"You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models"
" for inference. To quantize transformers models with AWQ algorithm, please refer to our"
" quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization "
)
if quantization_config is None and (load_in_8bit or load_in_4bit):
quantization_method_from_args = QuantizationMethod.BITS_AND_BYTES
quantization_config, kwargs = BitsAndBytesConfig.from_dict(
@@ -2805,6 +2813,36 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
elif quantization_method_from_config == QuantizationMethod.AWQ:
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run AWQ quantized model.")
if not is_auto_awq_available():
raise ImportError("Loading an AWQ quantized model requires auto-awq library (`pip install autoawq`)")
if not is_accelerate_available():
raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
if device_map is None:
logger.warning(
"You have loaded an AWQ model on CPU and have a CUDA device available, make sure to set "
"your model on a GPU device in order to run your model."
)
elif device_map is not None:
if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
raise ValueError(
"You are attempting to load an AWQ model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
if torch_dtype is None:
torch_dtype = torch.float16
else:
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
# Force-set to `True` for more mem efficiency
if low_cpu_mem_usage is None:
low_cpu_mem_usage = True
if (
is_8bit_serializable
@@ -3265,6 +3303,25 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.convert_model(model)
model._is_quantized_training_enabled = True
elif quantization_method_from_config == QuantizationMethod.AWQ:
from .integrations import get_keys_to_not_convert, replace_with_awq_linear
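# Collect modules that should be kept in their original precision (e.g. the lm_head)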
modules_to_not_convert = get_keys_to_not_convert(model)
if quantization_config is None:
quantization_config = AwqConfig.from_dict(config.quantization_config)
model, has_been_replaced = replace_with_awq_linear(
model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
)
model._is_quantized_training_enabled = False
if not has_been_replaced:
logger.warning(
"You are loading an AWQ model but no linear modules were found in your model."
" Please double check your model architecture, or submit an issue on github if you think this is"
" a bug."
)
if quantization_method_from_config is not None:
model.quantization_method = quantization_method_from_config

View File

@@ -52,6 +52,7 @@ from .integrations.deepspeed import is_deepspeed_available
from .utils import (
is_accelerate_available,
is_apex_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,
@@ -963,6 +964,13 @@ def require_auto_gptq(test_case):
return unittest.skipUnless(is_auto_gptq_available(), "test requires auto-gptq")(test_case)
def require_auto_awq(test_case):
"""
Decorator marking a test that requires autoawq
"""
return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
def require_phonemizer(test_case):
"""
Decorator marking a test that requires phonemizer

View File

@@ -104,6 +104,7 @@ from .import_utils import (
get_torch_version,
is_accelerate_available,
is_apex_available,
is_auto_awq_available,
is_auto_gptq_available,
is_bitsandbytes_available,
is_bs4_available,

View File

@@ -107,6 +107,8 @@ _onnx_available = _is_package_available("onnx")
_openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_pandas_available = _is_package_available("pandas")
_peft_available = _is_package_available("peft")
_phonemizer_available = _is_package_available("phonemizer")
@@ -675,6 +677,10 @@ def is_optimum_available():
return _optimum_available
def is_auto_awq_available():
return _auto_awq_available
def is_auto_gptq_available():
return _auto_gptq_available

View File

@@ -37,6 +37,17 @@ logger = logging.get_logger(__name__)
class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
GPTQ = "gptq"
AWQ = "awq"
class AWQLinearVersion(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
class AwqBackendPackingMethod(str, Enum):
AUTOAWQ = "autoawq"
LLMAWQ = "llm-awq"
@dataclass
@@ -418,3 +429,67 @@ class GPTQConfig(QuantizationConfigMixin):
f"""dataset needs to be either a list of string or a value in
['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
)
@dataclass
class AwqConfig(QuantizationConfigMixin):
"""
This is a wrapper class for all the attributes and features that you can play with for a model that has been
loaded using AWQ quantization, relying on the `auto-awq` backend.
Args:
bits (`int`, *optional*, defaults to 4):
The number of bits to quantize to.
group_size (`int`, *optional*, defaults to 128):
The group size to use for quantization. Recommended value is 128 and -1 uses per-column quantization.
zero_point (`bool`, *optional*, defaults to `True`):
Whether to use zero point quantization.
version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
The version of the quantization algorithm to use. GEMM is better for large batch sizes (e.g. >= 8), otherwise
GEMV is better (e.g. < 8).
backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
The quantization backend. Some models might be quantized using the `llm-awq` backend. This is useful for users
that quantize their own models using the `llm-awq` library.
"""
def __init__(
self,
bits: int = 4,
group_size: int = 128,
zero_point: bool = True,
version: AWQLinearVersion = AWQLinearVersion.GEMM,
backend: AwqBackendPackingMethod = AwqBackendPackingMethod.AUTOAWQ,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
self.bits = bits
self.group_size = group_size
self.zero_point = zero_point
self.version = version
self.backend = backend
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct
"""
if not torch.cuda.is_available():
raise ValueError("AWQ is only available on GPU")
if self.backend not in [AwqBackendPackingMethod.AUTOAWQ, AwqBackendPackingMethod.LLMAWQ]:
raise ValueError(
f"Only supported quantization backends in {AwqBackendPackingMethod.AUTOAWQ} and {AwqBackendPackingMethod.LLMAWQ} - not recognized backend {self.backend}"
)
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
raise ValueError(
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
)
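# The llm-awq kernels require a GPU with compute capability >= 8.0 (e.g. Ampere or newer)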
if self.backend == AwqBackendPackingMethod.LLMAWQ:
compute_capability = torch.cuda.get_device_capability()
major, minor = compute_capability
if major < 8:
raise ValueError("LLM-AWQ backend is only supported on GPUs with compute capability >= 8.0")

View File

View File

@@ -0,0 +1,221 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AwqConfig, OPTForCausalLM
from transformers.testing_utils import (
require_accelerate,
require_auto_awq,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available
if is_torch_available():
import torch
if is_accelerate_available():
from accelerate import init_empty_weights
@require_torch_gpu
class AwqConfigTest(unittest.TestCase):
def test_wrong_backend(self):
"""
Simple test that checks that an error is raised if a user passes a wrong backend
"""
# This should work fine
_ = AwqConfig(bits=4)
with self.assertRaises(ValueError):
AwqConfig(bits=4, backend="")
# LLMAWQ does not work on a T4
with self.assertRaises(ValueError):
AwqConfig(bits=4, backend="llm-awq")
def test_to_dict(self):
"""
Simple test that checks that converting a config to a dict gives a dict with the same values as the config object
"""
quantization_config = AwqConfig(bits=4)
config_to_dict = quantization_config.to_dict()
for key in config_to_dict:
self.assertEqual(getattr(quantization_config, key), config_to_dict[key])
def test_from_dict(self):
"""
Simple test that checks that building a config from a dict gives a config object with the same values as the dict
"""
dict = {"bits": 2, "zero_point": False, "backend": "autoawq"}
quantization_config = AwqConfig.from_dict(dict)
self.assertEqual(dict["bits"], quantization_config.bits)
self.assertEqual(dict["zero_point"], quantization_config.zero_point)
self.assertEqual(dict["backend"], quantization_config.backend)
@slow
@require_torch_gpu
@require_auto_awq
@require_accelerate
class AwqTest(unittest.TestCase):
# TODO: @younesbelkada change it to `TheBloke/Mistral-7B-v0.1-AWQ` in the future
model_name = "ybelkada/test-mistral-7b-v0.1-awq"
dummy_transformers_model_name = "bigscience/bloom-560m"
input_text = "Hello my name is"
EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Journalism and minoring in Spanish"
EXPECTED_OUTPUT_BF16 = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am a junior and I am majoring in Exercise and Sport Science with a"
device_map = "cuda"
# called only once for all tests in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
device_map=cls.device_map,
)
def test_quantized_model_conversion(self):
"""
Simple test that checks if the quantized model has been converted properly
"""
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from transformers.integrations.awq import replace_with_awq_linear
model_id = "facebook/opt-350m"
config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
quantization_config = AwqConfig(bits=4)
with init_empty_weights():
model = OPTForCausalLM(config)
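# Count the vanilla nn.Linear layers before conversion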
nb_linears = 0
for module in model.modules():
if isinstance(module, torch.nn.Linear):
nb_linears += 1
model, _ = replace_with_awq_linear(model, quantization_config=quantization_config)
nb_awq_linear = 0
for module in model.modules():
if isinstance(module, (WQLinear_GEMM, WQLinear_GEMV)):
nb_awq_linear += 1
self.assertEqual(nb_linears, nb_awq_linear)
# Try with `modules_to_not_convert`
with init_empty_weights():
model = OPTForCausalLM(config)
model, _ = replace_with_awq_linear(
model, quantization_config=quantization_config, modules_to_not_convert=["lm_head"]
)
nb_awq_linear = 0
for module in model.modules():
if isinstance(module, (WQLinear_GEMM, WQLinear_GEMV)):
nb_awq_linear += 1
self.assertEqual(nb_linears - 1, nb_awq_linear)
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_bf16(self):
"""
Simple test that checks if the quantized model is working properly with bf16
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype=torch.bfloat16).to(
torch_device
)
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)
def test_quantized_model_no_device_map(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name).to(torch_device)
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_raise_quantization(self):
"""
Simple test that checks that an error is raised if one passes a quantization config to quantize a model
"""
quantization_config = AwqConfig(bits=4)
with self.assertRaises(ValueError) as context:
_ = AutoModelForCausalLM.from_pretrained(
self.dummy_transformers_model_name, quantization_config=quantization_config
)
self.assertEqual(
str(context.exception),
"You cannot pass an `AwqConfig` when loading a model as you can only use AWQ models for inference. To quantize transformers models with AWQ algorithm, please refer to our quantization docs: https://huggingface.co/docs/transformers/main_classes/quantization ",
)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")
self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1, 2, 3})
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)