Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 11:11:05 +06:00)
Add AutoRound quantization support (#37393)
* add auto-round support * Update src/transformers/quantizers/auto.py Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> * fix style issue Signed-off-by: wenhuach <wenhuach87@gmail.com> * tiny change * tiny change * refine ut and doc * revert unnecessary change * tiny change * try to fix style issue * try to fix style issue * try to fix style issue * try to fix style issue * try to fix style issue * try to fix style issue * try to fix style issue * fix doc issue * Update tests/quantization/autoround/test_auto_round.py * fix comments * Update tests/quantization/autoround/test_auto_round.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * Update tests/quantization/autoround/test_auto_round.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * update doc * Update src/transformers/quantizers/quantizer_auto_round.py Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> * update * update * fix * try to fix style issue * Update src/transformers/quantizers/auto.py Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update docs/source/en/quantization/auto_round.md Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update docs/source/en/quantization/auto_round.md Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * Update docs/source/en/quantization/auto_round.md Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> * update * fix style issue * update doc * update doc * Refine the doc * refine doc * revert one change * set sym to True by default * Enhance the unit test's robustness. * update * add torch dtype * tiny change * add awq convert test * fix typo * update * fix packing format issue * use one gpu --------- Signed-off-by: wenhuach <wenhuach87@gmail.com> Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com> Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> Co-authored-by: Shen, Haihao <haihao.shen@intel.com>
This commit is contained in:
parent
9608908639
commit
b3492ff9f7
@@ -84,6 +84,9 @@ RUN python3 -m pip install --no-cache-dir compressed-tensors

# Add AMD Quark for quantization testing
RUN python3 -m pip install --no-cache-dir amd-quark

# Add AutoRound for quantization testing
RUN python3 -m pip install --no-cache-dir "auto-round>=0.5.0"

# Add transformers in editable mode
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]
@@ -92,3 +92,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

## QuarkConfig

[[autodoc]] QuarkConfig

## AutoRoundConfig

[[autodoc]] AutoRoundConfig
286 docs/source/en/quantization/auto_round.md Normal file
@@ -0,0 +1,286 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# AutoRound

[AutoRound](https://github.com/intel/auto-round) is an advanced quantization algorithm that delivers strong accuracy, even at 2-bit precision.
It leverages sign gradient descent to fine-tune both rounding values and min-max clipping thresholds in just 200 steps. Designed for broad compatibility, it seamlessly supports a wide range of LLMs and is actively expanding to cover more VLMs as well.
It also supports quantization and inference across multiple hardware platforms, including CPU, XPU, and CUDA.

AutoRound also offers a variety of useful features, including mixed-bit tuning and inference, lm-head quantization, support for exporting to formats like GPTQ/AWQ/GGUF, and flexible tuning recipes.
For a comprehensive overview and the latest updates, check out the AutoRound [README](https://github.com/intel/auto-round).

AutoRound was originally developed as part of the [Intel Neural Compressor](https://github.com/intel/neural-compressor), serving as a general-purpose model compression library for deep learning.
It has since evolved into a standalone library focused specifically on low-precision optimization for large language models (LLMs).
AutoRound remains fully integrated with the Intel Neural Compressor, and you can explore the repository for more details.

## Installation

```bash
pip install auto-round
```

## Supported Quantization Configurations

AutoRound supports several quantization configurations, with a mixed-bit sketch shown after this list:

- **Int8 Weight Only**
- **Int4 Weight Only**
- **Int3 Weight Only**
- **Int2 Weight Only**
- **Mixed-Bits Weight Only**
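For example, mixed-bit quantization is expressed through a per-layer override dictionary passed as `layer_config` (a minimal sketch based on the API example later in this page; the layer name is specific to the OPT architecture and purely illustrative):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Keep most layers at 4 bits, but drop one projection to 2 bits with a finer group size
layer_config = {"model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32}}
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True, layer_config=layer_config)
autoround.quantize_and_save("./tmp_autoround", format="auto_round")
```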
## Hardware Compatibility

AutoRound supports CPU, XPU, and CUDA for both quantization and inference.

## Quantization and Serialization (offline)

Currently, only offline mode is supported to generate quantized models.

<hfoptions id="quantization">
<hfoption id="quantization cmd">

### Command Line Usage
```bash
auto-round \
    --model facebook/opt-125m \
    --bits 4 \
    --group_size 128 \
    --output_dir ./tmp_autoround
```

AutoRound also offers two additional recipes, `auto-round-best` and `auto-round-light`, designed for optimal accuracy and improved speed, respectively.
For 2 bits, we recommend using `auto-round-best` or `auto-round`, for example as sketched below.
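A minimal sketch of invoking the recipe entry points, assuming they accept the same flags as the `auto-round` command above (bit width and group size here are illustrative choices):

```bash
# Best-accuracy recipe (slower), e.g. for a 2-bit run
auto-round-best \
    --model facebook/opt-125m \
    --bits 2 \
    --group_size 128 \
    --output_dir ./tmp_autoround_best

# Lighter, faster recipe for 4-bit settings
auto-round-light \
    --model facebook/opt-125m \
    --bits 4 \
    --group_size 128 \
    --output_dir ./tmp_autoround_light
```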
</hfoption>

<hfoption id="quantization auto-round api">

### AutoRound API Usage
This setting offers a better trade-off between accuracy and tuning cost, and is recommended in all scenarios.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
# mixed bits config
# layer_config = {"model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32}}
autoround = AutoRound(
    model,
    tokenizer,
    bits=bits,
    group_size=group_size,
    sym=sym,
    # enable_torch_compile=True,
    # layer_config=layer_config,
)

output_dir = "./tmp_autoround"
# format = 'auto_round' (default), 'auto_gptq', 'auto_awq'
autoround.quantize_and_save(output_dir, format='auto_round')
```

</hfoption>

<hfoption id="quantization auto-round-best">

### AutoRoundBest recipe
This setting provides the best accuracy in most scenarios but is 4–5× slower than the standard AutoRound recipe. It is especially recommended for 2-bit quantization and is a good choice if sufficient resources are available.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
autoround = AutoRound(
    model,
    tokenizer,
    bits=bits,
    group_size=group_size,
    sym=sym,
    nsamples=512,
    iters=1000,
    low_gpu_mem_usage=True,
)

output_dir = "./tmp_autoround"
autoround.quantize_and_save(output_dir, format='auto_round')
```
</hfoption>

<hfoption id="quantization auto-round-light">

### AutoRoundLight recipe
This setting offers the best speed (2–3× faster than the standard AutoRound recipe), but it may cause a significant accuracy drop for small models and 2-bit quantization. It is recommended for 4-bit settings and models larger than 3B.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
autoround = AutoRound(
    model,
    tokenizer,
    bits=bits,
    group_size=group_size,
    sym=sym,
    iters=50,
    lr=5e-3,
)

output_dir = "./tmp_autoround"
autoround.quantize_and_save(output_dir, format='auto_round')
```

</hfoption>

</hfoptions>

W4G128 average accuracy across 13 tasks (mmlu-pro, if_eval, gsm8k, etc.) and time cost. Testing was conducted on an Nvidia A100 80GB GPU with PyTorch 2.6.0 and `enable_torch_compile`:

| Model   | Qwen2.5-0.5B-Instruct | Falcon3-3B      | Qwen2.5-7B-Instruct | Meta-Llama-3.1-8B-Instruct | Falcon3-10B     | Qwen2.5-72B-Instruct |
|---------|-----------------------|-----------------|---------------------|----------------------------|-----------------|----------------------|
| 16bits  | 0.4192                | 0.5203          | 0.6470              | 0.6212                     | 0.6151          | 0.7229               |
| Best    | **0.4137**(7m)        | **0.5142**(23m) | 0.6426(58m)         | **0.6116**(65m)            | **0.6092**(81m) | 0.7242(575m)         |
| Default | 0.4129(2m)            | 0.5133(6m)      | 0.6441(13m)         | 0.6106(13m)                | 0.6080(18m)     | **0.7252**(118m)     |
| Light   | 0.4052(2m)            | 0.5108(3m)      | **0.6453**(5m)      | 0.6104(6m)                 | 0.6063(6m)      | 0.7243(37m)          |

## Inference

AutoRound automatically selects the best available backend based on the installed libraries and prompts the user to install additional libraries when a better backend is found.
<hfoptions id="inference">
<hfoption id="inference cpu">

### CPU

Supports 2, 4, and 8 bits. We recommend using intel-extension-for-pytorch (IPEX) for 4-bit inference.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```

</hfoption>

<hfoption id="inference xpu">

### XPU

Supports 4 bits only. We recommend using intel-extension-for-pytorch (IPEX) for inference.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="xpu", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```

</hfoption>

<hfoption id="inference cuda">

### CUDA

Supports 2, 3, 4, and 8 bits. We recommend using GPTQModel for 4- and 8-bit inference.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```

</hfoption>

<hfoption id="inference backend">

### Specify Inference Backend

AutoRound automatically selects the backend for each layer based on compatibility. In general, the priority order is Marlin > ExLLaMAV2 > Triton, but the final choice depends on factors such as group size, bit width, packing format, hardware device, and other implementation details. For more details, please refer to [backends](https://github.com/intel/auto-round?tab=readme-ov-file#specify-backend).

The automatically selected backend may not always be the most suitable for certain devices.
You can specify your preferred backend, such as "ipex" for CPU and XPU, or "marlin"/"exllamav2"/"triton" for CUDA, according to your needs or hardware compatibility. Please note that additional corresponding libraries may be required.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoRoundConfig

model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
quantization_config = AutoRoundConfig(backend="ipex")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
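The same pattern applies on CUDA with a GPU kernel backend (a minimal sketch; whether a given kernel such as `marlin` is actually usable depends on the bit width, group size, and the kernels installed):

```python
from transformers import AutoModelForCausalLM, AutoRoundConfig

model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
# Request the Marlin kernel explicitly instead of relying on automatic selection
quantization_config = AutoRoundConfig(backend="marlin")
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
)
```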

</hfoption>

<hfoption id="format convert">

### Convert GPTQ/AWQ to AutoRound

Most GPTQ/AWQ models can be converted to the AutoRound format for better compatibility and support with Intel devices. Please note that the quantization config will be changed if the model is serialized.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoRoundConfig

model_name = "ybelkada/opt-125m-gptq-4bit"
quantization_config = AutoRoundConfig()
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
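If you then save the converted model, the checkpoint is re-serialized with an AutoRound-style quantization config rather than the original GPTQ/AWQ one (a minimal sketch, assuming a writable output directory):

```python
# Saving the converted model writes an AutoRound-format quantization config to its config.json
save_dir = "./opt-125m-gptq-4bit-autoround"
model.save_pretrained(save_dir)
tokenizer.save_pretrained(save_dir)
```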

</hfoption>

</hfoptions>

## Issues

If you encounter any issues with the transformers integration, please open an issue on
the [transformers](https://github.com/huggingface/transformers/issues) repository.
If you encounter any issues with auto-round, please open an issue on
the [AutoRound](https://github.com/intel/auto-round/issues) repository.

## Acknowledgement
Special thanks to open-source low-precision libraries such as AutoGPTQ, AutoAWQ, GPTQModel, Triton, Marlin, and ExLLaMAV2 for providing low-precision CUDA kernels, which are leveraged in AutoRound.

## Contribution
Contributions to [AutoRound](https://github.com/intel/auto-round/pulls) are welcome and greatly appreciated!
Whether it's fixing bugs, improving documentation, adding new features, or suggesting improvements, your help is always valued.
@@ -22,25 +22,26 @@ Transformers supports many quantization methods, each with their pros and cons,

Use the Space below to help you pick a quantization method depending on your hardware and number of bits to quantize to.

| Quantization Method | On the fly quantization | CPU | CUDA GPU | ROCm GPU | Metal (Apple Silicon) | Intel GPU | Torch compile() | Bits | PEFT Fine Tuning | Serializable with 🤗Transformers | 🤗Transformers Support | Link to library |
|-----------------------------------------------|----------------------|-----------------|----------|-----------|------------------------------------|-----------------|-----------------|---------------|------------------|-----------------------------|-------------------------|---------------------------------------------|
| [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM |
| [AWQ](./awq) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ |
| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🔴 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
| [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HIGGS](./higgs) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 2/4 | 🔴 | 🟢 | 🟢 | https://github.com/HanGuo97/flute |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [Quark](./quark) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
| Quantization Method | On the fly quantization | CPU | CUDA GPU | ROCm GPU | Metal (Apple Silicon) | Intel GPU | Torch compile() | Bits | PEFT Fine Tuning | Serializable with 🤗Transformers | 🤗Transformers Support | Link to library |
|-------------------------------------------|----------------------|-----------------|----------|-----------|------------------------------------|-----------------|-----------------|--------------|------------------|-----------------------------|-------------------------|---------------------------------------------|
| [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM |
| [AutoRound](./auto_round) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 🔴 | 2/3/4/8 | 🔴 | 🟢 | 🟢 | https://github.com/intel/auto-round |
| [AWQ](./awq) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ |
| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🔴 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
| [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HIGGS](./higgs) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 2/4 | 🔴 | 🟢 | 🟢 | https://github.com/HanGuo97/flute |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [Quark](./quark) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |

## Resources

@@ -259,6 +259,7 @@ _import_structure = {
    ],
    "utils.quantization_config": [
        "AqlmConfig",
        "AutoRoundConfig",
        "AwqConfig",
        "BitNetConfig",
        "BitsAndBytesConfig",
@@ -754,6 +755,7 @@ if TYPE_CHECKING:
    # bitsandbytes config
    from .utils.quantization_config import (
        AqlmConfig,
        AutoRoundConfig,
        AwqConfig,
        BitNetConfig,
        BitsAndBytesConfig,
@@ -19,6 +19,7 @@ from ..models.auto.configuration_auto import AutoConfig
from ..utils import logging
from ..utils.quantization_config import (
    AqlmConfig,
    AutoRoundConfig,
    AwqConfig,
    BitNetConfig,
    BitsAndBytesConfig,
@@ -39,6 +40,7 @@ from ..utils.quantization_config import (
)
from .base import HfQuantizer
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_auto_round import AutoRoundQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
@@ -75,6 +77,7 @@ AUTO_QUANTIZER_MAPPING = {
    "vptq": VptqHfQuantizer,
    "spqr": SpQRHfQuantizer,
    "fp8": FineGrainedFP8HfQuantizer,
    "auto-round": AutoRoundQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -95,6 +98,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "vptq": VptqConfig,
    "spqr": SpQRConfig,
    "fp8": FineGrainedFP8Config,
    "auto-round": AutoRoundConfig,
}

logger = logging.get_logger(__name__)
@@ -195,10 +199,16 @@ class AutoHfQuantizer:
        warning_msg = ""

        if isinstance(quantization_config, dict):
            quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
            # Convert the config based on the type of quantization_config_from_args (e.g., AutoRoundConfig), which takes priority before automatic configuration dispatch.
            if isinstance(quantization_config_from_args, AutoRoundConfig):
                quantization_config = AutoRoundConfig.from_dict(quantization_config)
            else:
                quantization_config = AutoQuantizationConfig.from_dict(quantization_config)

        if (
            isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
            isinstance(
                quantization_config, (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig)
            )
            and quantization_config_from_args is not None
        ):
            # special case for GPTQ / AWQ / FbgemmFp8 config collision
81 src/transformers/quantizers/quantizer_auto_round.py Normal file
@@ -0,0 +1,81 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..utils import is_auto_round_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class AutoRoundQuantizer(HfQuantizer):
    """
    Quantizer of the AutoRound method. (https://arxiv.org/pdf/2309.05516)
    """

    # AutoRound requires data calibration - we support only inference
    requires_calibration = True
    required_packages = ["auto_round"]

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        self.device_map = kwargs.get("device_map", None)
        if not is_auto_round_available():
            raise ImportError(
                "Loading an AutoRound quantized model requires the auto-round library (`pip install 'auto-round>=0.5'`)"
            )

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            torch_dtype = torch.bfloat16
            logger.info("Loading the model in `torch.bfloat16`. To overwrite it, set `torch_dtype` manually.")
        return torch_dtype

    def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
        if model.__class__.main_input_name != "input_ids":
            logger.warning("AutoRound offers only limited support for models that are not strictly text-based.")
        from auto_round.inference.convert_model import convert_hf_model, infer_target_device

        if self.pre_quantized:
            target_device = infer_target_device(self.device_map)
            model, used_backends = convert_hf_model(model, target_device)
            self.used_backends = used_backends

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        if self.pre_quantized:
            from auto_round.inference.convert_model import post_init

            post_init(model, self.used_backends)
        else:
            raise ValueError("AutoRound only supports pre-quantized models.")

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    def is_serializable(self, safe_serialization=None):
        ## for gptq/awq models, the quantization config will be changed
        return True
@@ -70,6 +70,7 @@ from .utils import (
    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_auto_round_available,
    is_av_available,
    is_bitsandbytes_available,
    is_bitsandbytes_multi_backend_available,
@@ -1297,6 +1298,13 @@ def require_auto_awq(test_case):
    return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)


def require_auto_round(test_case):
    """
    Decorator for auto_round dependency
    """
    return unittest.skipUnless(is_auto_round_available(), "test requires autoround")(test_case)


def require_optimum_quanto(test_case):
    """
    Decorator for quanto dependency
@@ -123,6 +123,7 @@ from .import_utils import (
    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_auto_round_available,
    is_av_available,
    is_bitsandbytes_available,
    is_bitsandbytes_multi_backend_available,
@@ -107,7 +107,7 @@ XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"

AUTOROUND_MIN_VERSION = "0.5.0"

_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
@@ -159,6 +159,7 @@ _openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
_gptqmodel_available = _is_package_available("gptqmodel")
_auto_round_available, _auto_round_version = _is_package_available("auto_round", return_version=True)
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quark_available = _is_package_available("quark")
@@ -1101,6 +1102,10 @@ def is_auto_awq_available():
    return _auto_awq_available


def is_auto_round_available(min_version: str = AUTOROUND_MIN_VERSION):
    return _auto_round_available and version.parse(_auto_round_version) >= version.parse(min_version)


def is_optimum_quanto_available():
    # `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto`
    return _is_optimum_quanto_available
@@ -63,6 +63,7 @@ class QuantizationMethod(str, Enum):
    SPQR = "spqr"
    FP8 = "fp8"
    QUARK = "quark"
    AUTOROUND = "auto-round"


class AWQLinearVersion(str, Enum):
@@ -204,6 +205,75 @@ class QuantizationConfigMixin:
        return unused_kwargs


@dataclass
class AutoRoundConfig(QuantizationConfigMixin):
    """This is a wrapper class about all possible attributes and features that you can play with a model that has been
    loaded using AutoRound quantization.

    Args:
        bits (`int`, *optional*, defaults to 4):
            The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
        group_size (`int`, *optional*, defaults to 128): Group-size value
        sym (`bool`, *optional*, defaults to `True`): Symmetric quantization or not
        backend (`str`, *optional*, defaults to `"auto"`): The kernel to use, e.g., ipex, marlin, exllamav2, triton, etc. Ref. https://github.com/intel/auto-round?tab=readme-ov-file#specify-backend
    """

    def __init__(
        self,
        bits: int = 4,
        group_size: int = 128,
        sym: bool = True,
        backend: str = "auto",
        **kwargs,
    ):
        self.bits = bits
        self.group_size = group_size
        self.sym = sym
        self.backend = backend
        self.packing_format = "auto_round:gptq"
        if kwargs is not None:
            for key, value in kwargs.items():
                setattr(self, key, value)
        self.quant_method = QuantizationMethod.AUTOROUND
        self.post_init()

    def post_init(self):
        r"""Safety checker that arguments are correct."""
        if self.bits not in [2, 3, 4, 8]:
            raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
        if self.group_size != -1 and self.group_size <= 0:
            raise ValueError("group_size must be greater than 0 or equal to -1")

    def get_loading_attributes(self):
        loading_attibutes_dict = {"backend": self.backend}
        return loading_attibutes_dict

    def to_dict(self):
        config_dict = super().to_dict()
        return config_dict

    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        quant_method = config_dict["quant_method"]
        if "auto-round" not in quant_method and "gptq" not in quant_method and "awq" not in quant_method:
            raise NotImplementedError(
                "Failed to convert to auto_round format. Only `gptqv1`, `awq`, and `auto-round` formats are supported."
            )

        if "gptq" in quant_method and "meta" in config_dict:
            raise NotImplementedError("Failed to convert gptq format to auto_round format. Only supports `gptqv1`")

        if "awq" in quant_method and config_dict.get("version", "gemm") != "gemm":
            raise NotImplementedError(
                "Failed to convert awq format to auto_round format. Only supports awq format with gemm version"
            )

        if "auto-round" not in quant_method:
            config_dict["packing_format"] = f"auto_round:{quant_method}"

        return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)


@dataclass
class HqqConfig(QuantizationConfigMixin):
    """
0 tests/quantization/autoround/__init__.py Normal file
209 tests/quantization/autoround/test_auto_round.py Normal file
@@ -0,0 +1,209 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest

from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
from transformers.testing_utils import (
    require_accelerate,
    require_auto_round,
    require_intel_extension_for_pytorch,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_torch_available


if is_torch_available():
    import torch


@slow
@require_torch_gpu
@require_auto_round
@require_accelerate
class AutoRoundTest(unittest.TestCase):
    model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
    input_text = "There is a girl who likes adventure,"
    EXPECTED_OUTPUTS = set()
    ## Different backends may produce slight variations in output
    EXPECTED_OUTPUTS.add(
        "There is a girl who likes adventure, and she has been exploring the world "
        "for many years. She travels to different countries and cultures, trying new "
        "things every day. One of her favorite places to visit is a small village in "
        "the mountains where"
    )
    EXPECTED_OUTPUTS.add(
        "There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering"
    )

    device_map = "cuda"

    # called only once for all tests in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        torch.cuda.synchronize()
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        output = self.quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
        self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_raise_if_non_quantized(self):
        model_id = "facebook/opt-125m"
        quantization_config = AutoRoundConfig(bits=4)
        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

    def test_quantized_model_bf16(self):
        """
        Simple test that checks if the quantized model is working properly with bf16
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        quantization_config = AutoRoundConfig(backend="triton")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map=self.device_map,
            quantization_config=quantization_config,
        )

        output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
        self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    @require_intel_extension_for_pytorch
    def test_quantized_model_on_cpu(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt")

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto")
        output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)

        self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """

        ## some backends like marlin/ipex will repack the weights, which changes the weight shapes
        with tempfile.TemporaryDirectory() as tmpdirname:
            quantization_config = AutoRoundConfig(backend="triton")
            quantized_model = AutoModelForCausalLM.from_pretrained(
                self.model_name,
                device_map=self.device_map,
                torch_dtype=torch.float16,
                quantization_config=quantization_config,
            )

            quantized_model.save_pretrained(tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda")

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=40, do_sample=False)
            self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        quantization_config = AutoRoundConfig(backend="triton")
        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name, device_map="auto", quantization_config=quantization_config, torch_dtype="auto"
        )

        output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)

        self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

    def test_convert_from_gptq(self):
        """
        Simple test that checks if auto-round works properly with the gptq format
        """
        model_name = "ybelkada/opt-125m-gptq-4bit"

        quantization_config = AutoRoundConfig()

        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        text = "There is a girl who likes adventure,"
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])

    @require_intel_extension_for_pytorch
    def test_convert_from_awq_cpu(self):
        """
        Simple test that checks if auto-round works properly with the awq format
        """
        model_name = "casperhansen/opt-125m-awq"

        quantization_config = AutoRoundConfig()

        model = AutoModelForCausalLM.from_pretrained(
            model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto"
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        text = "There is a girl who likes adventure,"
        inputs = tokenizer(text, return_tensors="pt").to(model.device)
        tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])

    def test_mixed_bits(self):
        """
        Simple test that checks if auto-round works properly with mixed bits
        """
        model_name = "facebook/opt-125m"
        model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        layer_config = {
            "model.decoder.layers.0.self_attn.k_proj": {"bits": 8},
            "model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32},
        }

        bits, group_size, sym = 4, 128, True
        from auto_round import AutoRound

        autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config)
        with tempfile.TemporaryDirectory() as tmpdirname:
            autoround.quantize_and_save(output_dir=tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda")
            text = "There is a girl who likes adventure,"
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])