Add AutoRound quantization support (#37393)

* add auto-round support

* Update src/transformers/quantizers/auto.py

Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>

* fix style issue

Signed-off-by: wenhuach <wenhuach87@gmail.com>

* tiny change

* tiny change

* refine ut and doc

* revert unnecessary change

* tiny change

* try to fix style issue

* try to fix style issue

* try to fix style issue

* try to fix style issue

* try to fix style issue

* try to fix style issue

* try to fix style issue

* fix doc issue

* Update tests/quantization/autoround/test_auto_round.py

* fix comments

* Update tests/quantization/autoround/test_auto_round.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/autoround/test_auto_round.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* update doc

* Update src/transformers/quantizers/quantizer_auto_round.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* update

* update

* fix

* try to fix style issue

* Update src/transformers/quantizers/auto.py

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update docs/source/en/quantization/auto_round.md

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update docs/source/en/quantization/auto_round.md

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* Update docs/source/en/quantization/auto_round.md

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>

* update

* fix style issue

* update doc

* update doc

* Refine the doc

* refine doc

* revert one change

* set sym to True by default

* Enhance the unit test's robustness.

* update

* add torch dtype

* tiny change

* add awq convert test

* fix typo

* update

* fix packing format issue

* use one gpu

---------

Signed-off-by: wenhuach <wenhuach87@gmail.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Shen, Haihao <haihao.shen@intel.com>
This commit is contained in:
Wenhua Cheng 2025-04-22 19:56:54 +08:00 committed by GitHub
parent 9608908639
commit b3492ff9f7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 702 additions and 22 deletions

View File

@ -84,6 +84,9 @@ RUN python3 -m pip install --no-cache-dir compressed-tensors
# Add AMD Quark for quantization testing
RUN python3 -m pip install --no-cache-dir amd-quark
# Add AutoRound for quantization testing
RUN python3 -m pip install --no-cache-dir "auto-round>=0.5.0"
# Add transformers in editable mode
RUN python3 -m pip install --no-cache-dir -e ./transformers[dev-torch]

View File

@ -92,3 +92,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## QuarkConfig
[[autodoc]] QuarkConfig
## AutoRoundConfig
[[autodoc]] AutoRoundConfig

View File

@ -0,0 +1,286 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# AutoRound
[AutoRound](https://github.com/intel/auto-round) is an advanced quantization algorithm that delivers strong accuracy, even at 2-bit precision.
It leverages sign gradient descent to fine-tune both rounding values and min-max clipping thresholds in just 200 steps. Designed for broad compatibility, it seamlessly supports a wide range of LLMs and is actively expanding to cover more VLMs as well.
It also supports quantization and inference across multiple hardware platforms, including CPU, XPU, and CUDA.
AutoRound also offers a variety of useful features, including mixed-bit tuning and inference, lm-head quantization, support for exporting to formats like GPTQ/AWQ/GGUF, and flexible tuning recipes.
For a comprehensive overview and the latest updates, check out the AutoRound [README](https://github.com/intel/auto-round).
AutoRound was originally developed as part of the [Intel Neural Compressor](https://github.com/intel/neural-compressor), serving as a general-purpose model compression library for deep learning.
It has since evolved into a standalone library focused specifically on low-precision optimization for large language models (LLMs).
AutoRound remains fully integrated with the Intel Neural Compressor, and you can explore the repository for more details.
## Installation
```bash
pip install auto-round
```
## Supported Quantization Configurations
AutoRound supports several quantization configurations:
- **Int8 Weight Only**
- **Int4 Weight Only**
- **Int3 Weight Only**
- **Int2 Weight Only**
- **Mixed bits Weight only**
## Hardware Compatibility
CPU, XPU, and CUDA for both quantization and inference.
## Quantization and Serialization (offline)
Currently, only offline mode is supported to generate quantized models.
<hfoptions id="quantization">
<hfoption id="quantization cmd">
### Command Line Usage
```bash
auto-round \
--model facebook/opt-125m \
--bits 4 \
--group_size 128 \
--output_dir ./tmp_autoround
```
AutoRound also offer another two recipes, `auto-round-best` and `auto-round-light`, designed for optimal accuracy and improved speed, respectively.
For 2 bits, we recommend using `auto-round-best` or `auto-round`.
</hfoption>
<hfoption id="quantization auto-round api">
### AutoRound API Usage
This setting offers a better trade-off between accuracy and tuning cost, and is recommended in all scenarios.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
# mixed bits config
# layer_config = {"model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32}}
autoround = AutoRound(
model,
tokenizer,
bits=bits,
group_size=group_size,
sym=sym,
# enable_torch_compile=True,
# layer_config=layer_config,
)
output_dir = "./tmp_autoround"
# format= 'auto_round'(default), 'auto_gptq', 'auto_awq'
autoround.quantize_and_save(output_dir, format='auto_round')
```
</hfoption>
<hfoption id="quantization auto-round-best">
### AutoRoundBest recipe
This setting provides the best accuracy in most scenarios but is 45× slower than the standard AutoRound recipe. It is especially recommended for 2-bit quantization and is a good choice if sufficient resources are available.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
autoround = AutoRound(
model,
tokenizer,
bits=bits,
group_size=group_size,
sym=sym,
nsamples=512,
iters=1000,
low_gpu_mem_usage=True
)
output_dir = "./tmp_autoround"
autoround.quantize_and_save(output_dir, format='auto_round')
```
</hfoption>
<hfoption id="quantization auto-round-light">
### AutoRoundLight recipe
This setting offers the best speed (2 - 3X faster than AutoRound), but it may cause a significant accuracy drop for small models and 2-bit quantization. It is recommended for 4-bit settings and models larger than 3B.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
bits, group_size, sym = 4, 128, True
autoround = AutoRound(
model,
tokenizer,
bits=bits,
group_size=group_size,
sym=sym,
iters=50,
lr=5e-3,
)
output_dir = "./tmp_autoround"
autoround.quantize_and_save(output_dir, format='auto_round')
```
</hfoption>
</hfoptions>
W4G128 Average Accuracy of 13 tasks (mmlu-pro, if_eval, gsm8k, etc) and Time Cost Results (Testing was conducted on the Nvidia A100 80G using the version of PyTorch 2.6.0 with enable_torch_compile):
| Model | Qwen2.5-0.5B-Instruct | Falcon3-3B | Qwen2.5-7B-Instruct | Meta-Llama-3.1-8B-Instruct | Falcon3-10B | Qwen2.5-72B-Instruct |
|---------|--------------------|---------------|------------------|----------------------------|---------------|-------------------|
| 16bits | 0.4192 | 0.5203 | 0.6470 | 0.6212 | 0.6151 | 0.7229 |
| Best | **0.4137**(7m) | **0.5142**(23m) | 0.6426(58m) | **0.6116**(65m) | **0.6092**(81m) | 0.7242(575m) |
| Default | 0.4129(2m) | 0.5133(6m) | 0.6441(13m) | 0.6106(13m) | 0.6080(18m) | **0.7252**(118m) |
| Light | 0.4052(2m) | 0.5108(3m) | **0.6453**(5m) | 0.6104(6m) | 0.6063(6m) | 0.7243(37m) |
## Inference
AutoRound automatically selects the best available backend based on the installed libraries and prompts the user to install additional libraries when a better backend is found.
<hfoptions id="inference">
<hfoption id="inference cpu">
### CPU
Supports 2, 4, and 8 bits. We recommend using intel-extension-for-pytorch (IPEX) for 4 bits inference.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
<hfoption>
<hfoption id="inference xpu">
### XPU
Supports 4 bits only. We recommend using intel-extension-for-pytorch (IPEX) for inference.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="xpu", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
<hfoption>
<hfoption id="inference cuda">
### CUDA
Supports 2, 3, 4, and 8 bits. We recommend using GPTQModel for 4 and 8 bits inference.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer
model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
<hfoption>
<hfoption id="inference backend">
### Specify Inference Backend
AutoRound automatically selects the backend for each layer based on compatibility. In general, the priority order is Marlin > ExLLaMAV2 > Triton, but the final choice depends on factors such as group size, bit width, packing format, hardware device, and other implementation details. For more details, please refer to [backends](https://github.com/intel/auto-round?tab=readme-ov-file#specify-backend),
The backend may not always be the most suitable for certain devices.
You can specify your preferred backend such as "ipex" for CPU and CPU, "marlin/exllamav2/triton" for CUDA, according to your needs or hardware compatibility. Please note that additional corresponding libraries may be required.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoRoundConfig
model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
quantization_config = AutoRoundConfig(backend="ipex")
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
<hfoption>
<hfoption id="format convert">
### Convert GPTQ/AWQ to AutoRound
Most GPTQ/AWQ models can be converted to the AutoRound format for better compatibility and support with Intel devices. Please note that the quantization config will be changed if the model is serialized.
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoRoundConfig
model_name = "ybelkada/opt-125m-gptq-4bit"
quantization_config = AutoRoundConfig()
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50, do_sample=False)[0]))
```
<hfoption>
<hfoptions>
## Issues
If you encounter any issues with the transformers integration, please open an issue on
the [transformers](https://github.com/huggingface/transformers/issues) repository.
If you encounter any issues with auto-round, please open an issue on
the [AutoRound](https://github.com/intel/auto-round/issues) repository.
## Acknowledgement
Special thanks to open-source low precision libraries such as AutoGPTQ, AutoAWQ, GPTQModel, Triton, Marlin, and ExLLaMAV2 for providing low-precision CUDA kernels, which are leveraged in AutoRound.
## Contribution
Contributions to [AutoRound](https://github.com/intel/auto-round/pulls) are welcome and greatly appreciated!
Whether it's fixing bugs, improving documentation, adding new features, or suggesting improvements, your help is always valued.

View File

@ -22,25 +22,26 @@ Transformers supports many quantization methods, each with their pros and cons,
Use the Space below to help you pick a quantization method depending on your hardware and number of bits to quantize to.
| Quantization Method | On the fly quantization | CPU | CUDA GPU | ROCm GPU | Metal (Apple Silicon) | Intel GPU | Torch compile() | Bits | PEFT Fine Tuning | Serializable with 🤗Transformers | 🤗Transformers Support | Link to library |
|-----------------------------------------------|----------------------|-----------------|----------|-----------|------------------------------------|-----------------|-----------------|---------------|------------------|-----------------------------|-------------------------|---------------------------------------------|
| [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM |
| [AWQ](./awq) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ |
| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🔴 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
| [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HIGGS](./higgs) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 2/4 | 🔴 | 🟢 | 🟢 | https://github.com/HanGuo97/flute |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [Quark](./quark) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
| Quantization Method | On the fly quantization | CPU | CUDA GPU | ROCm GPU | Metal (Apple Silicon) | Intel GPU | Torch compile() | Bits | PEFT Fine Tuning | Serializable with 🤗Transformers | 🤗Transformers Support | Link to library |
|-------------------------------------------|----------------------|-----------------|----------|-----------|------------------------------------|-----------------|-----------------|--------------|------------------|-----------------------------|-------------------------|---------------------------------------------|
| [AQLM](./aqlm) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM |
| [AutoRound](./auto_round) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🟢 | 🔴 | 2/3/4/8 | 🔴 | 🟢 | 🟢 | https://github.com/intel/auto-round |
| [AWQ](./awq) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ |
| [bitsandbytes](./bitsandbytes) | 🟢 | 🟡 | 🟢 | 🟡 | 🔴 | 🟡 | 🔴 | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| [GGUF / GGML (llama.cpp)](../gguf) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf) | [See Notes](../gguf) | https://github.com/ggerganov/llama.cpp |
| [GPTQModel](./gptq) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
| [AutoGPTQ](./gptq) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HIGGS](./higgs) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 2/4 | 🔴 | 🟢 | 🟢 | https://github.com/HanGuo97/flute |
| [HQQ](./hqq) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao) | 🟢 | 🟢 | 🟢 | 🔴 | 🟡 | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
| [FINEGRAINED_FP8](./finegrained_fp8) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | |
| [SpQR](./spqr) | 🔴 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 3 | 🔴 | 🟢 | 🟢 | https://github.com/Vahe1994/SpQR/ |
| [Quark](./quark) | 🔴 | 🟢 | 🟢 | 🟢 | 🟢 | 🟢 | ? | 2/4/6/8/9/16 | 🔴 | 🔴 | 🟢 | https://quark.docs.amd.com/latest/ |
## Resources

View File

@ -259,6 +259,7 @@ _import_structure = {
],
"utils.quantization_config": [
"AqlmConfig",
"AutoRoundConfig",
"AwqConfig",
"BitNetConfig",
"BitsAndBytesConfig",
@ -754,6 +755,7 @@ if TYPE_CHECKING:
# bitsandbytes config
from .utils.quantization_config import (
AqlmConfig,
AutoRoundConfig,
AwqConfig,
BitNetConfig,
BitsAndBytesConfig,

View File

@ -19,6 +19,7 @@ from ..models.auto.configuration_auto import AutoConfig
from ..utils import logging
from ..utils.quantization_config import (
AqlmConfig,
AutoRoundConfig,
AwqConfig,
BitNetConfig,
BitsAndBytesConfig,
@ -39,6 +40,7 @@ from ..utils.quantization_config import (
)
from .base import HfQuantizer
from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_auto_round import AutoRoundQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bitnet import BitNetHfQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
@ -75,6 +77,7 @@ AUTO_QUANTIZER_MAPPING = {
"vptq": VptqHfQuantizer,
"spqr": SpQRHfQuantizer,
"fp8": FineGrainedFP8HfQuantizer,
"auto-round": AutoRoundQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@ -95,6 +98,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"vptq": VptqConfig,
"spqr": SpQRConfig,
"fp8": FineGrainedFP8Config,
"auto-round": AutoRoundConfig,
}
logger = logging.get_logger(__name__)
@ -195,10 +199,16 @@ class AutoHfQuantizer:
warning_msg = ""
if isinstance(quantization_config, dict):
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
# Convert the config based on the type of quantization_config_from_args (e.g., AutoRoundConfig), which takes priority before automatic configuration dispatch.
if isinstance(quantization_config_from_args, AutoRoundConfig):
quantization_config = AutoRoundConfig.from_dict(quantization_config)
else:
quantization_config = AutoQuantizationConfig.from_dict(quantization_config)
if (
isinstance(quantization_config, (GPTQConfig, AwqConfig, FbgemmFp8Config, CompressedTensorsConfig))
isinstance(
quantization_config, (GPTQConfig, AwqConfig, AutoRoundConfig, FbgemmFp8Config, CompressedTensorsConfig)
)
and quantization_config_from_args is not None
):
# special case for GPTQ / AWQ / FbgemmFp8 config collision

View File

@ -0,0 +1,81 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional
from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_round_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
class AutoRoundQuantizer(HfQuantizer):
"""
Quantizer of the AutoRound method. (https://arxiv.org/pdf/2309.05516)
"""
# AutoRound requires data calibration - we support only inference
requires_calibration = True
required_packages = ["auto_round"]
def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
super().__init__(quantization_config, **kwargs)
def validate_environment(self, *args, **kwargs):
self.device_map = kwargs.get("device_map", None)
if not is_auto_round_available():
raise ImportError(
"Loading an AutoRound quantized model requires auto-round library (`pip install 'auto-round>=0.5'`)"
)
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if torch_dtype is None:
torch_dtype = torch.bfloat16
logger.info("Loading the model in `torch.bfloat16`. To overwrite it, set `torch_dtype` manually.")
return torch_dtype
def _process_model_before_weight_loading(self, model: "PreTrainedModel", **kwargs):
if model.__class__.main_input_name != "input_ids":
logger.warning("AutoRound offers only limited support for models that are not strictly text-based.")
from auto_round.inference.convert_model import convert_hf_model, infer_target_device
if self.pre_quantized:
target_device = infer_target_device(self.device_map)
model, used_backends = convert_hf_model(model, target_device)
self.used_backends = used_backends
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
if self.pre_quantized:
from auto_round.inference.convert_model import post_init
post_init(model, self.used_backends)
else:
raise ValueError("AutoRound only sports pre-quantized models.")
@property
def is_trainable(self, model: Optional["PreTrainedModel"] = None):
return False
def is_serializable(self, safe_serialization=None):
## for gptq/awq models, the quantization config will be changed
return True

View File

@ -70,6 +70,7 @@ from .utils import (
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_auto_round_available,
is_av_available,
is_bitsandbytes_available,
is_bitsandbytes_multi_backend_available,
@ -1297,6 +1298,13 @@ def require_auto_awq(test_case):
return unittest.skipUnless(is_auto_awq_available(), "test requires autoawq")(test_case)
def require_auto_round(test_case):
"""
Decorator for auto_round dependency
"""
return unittest.skipUnless(is_auto_round_available(), "test requires autoround")(test_case)
def require_optimum_quanto(test_case):
"""
Decorator for quanto dependency

View File

@ -123,6 +123,7 @@ from .import_utils import (
is_aqlm_available,
is_auto_awq_available,
is_auto_gptq_available,
is_auto_round_available,
is_av_available,
is_bitsandbytes_available,
is_bitsandbytes_multi_backend_available,

View File

@ -107,7 +107,7 @@ XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"
AUTOROUND_MIN_VERSION = "0.5.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
@ -159,6 +159,7 @@ _openai_available = _is_package_available("openai")
_optimum_available = _is_package_available("optimum")
_auto_gptq_available = _is_package_available("auto_gptq")
_gptqmodel_available = _is_package_available("gptqmodel")
_auto_round_available, _auto_round_version = _is_package_available("auto_round", return_version=True)
# `importlib.metadata.version` doesn't work with `awq`
_auto_awq_available = importlib.util.find_spec("awq") is not None
_quark_available = _is_package_available("quark")
@ -1101,6 +1102,10 @@ def is_auto_awq_available():
return _auto_awq_available
def is_auto_round_available(min_version: str = AUTOROUND_MIN_VERSION):
return _auto_round_available and version.parse(_auto_round_version) >= version.parse(min_version)
def is_optimum_quanto_available():
# `importlib.metadata.version` doesn't work with `optimum.quanto`, need to put `optimum_quanto`
return _is_optimum_quanto_available

View File

@ -63,6 +63,7 @@ class QuantizationMethod(str, Enum):
SPQR = "spqr"
FP8 = "fp8"
QUARK = "quark"
AUTOROUND = "auto-round"
class AWQLinearVersion(str, Enum):
@ -204,6 +205,75 @@ class QuantizationConfigMixin:
return unused_kwargs
@dataclass
class AutoRoundConfig(QuantizationConfigMixin):
"""This is a wrapper class about all possible attributes and features that you can play with a model that has been
loaded AutoRound quantization.
Args:
bits (`int`, *optional*, defaults to 4):
The number of bits to quantize to, supported numbers are (2, 3, 4, 8).
group_size (`int`, *optional*, defaults to 128): Group-size value
sym (`bool`, *optional*, defaults to `True`): Symmetric quantization or not
backend (`str`, *optional*, defaults to `"auto"`): The kernel to use, e.g., ipex,marlin, exllamav2, triton, etc. Ref. https://github.com/intel/auto-round?tab=readme-ov-file#specify-backend
"""
def __init__(
self,
bits: int = 4,
group_size: int = 128,
sym: bool = True,
backend: str = "auto",
**kwargs,
):
self.bits = bits
self.group_size = group_size
self.sym = sym
self.backend = backend
self.packing_format = "auto_round:gptq"
if kwargs is not None:
for key, value in kwargs.items():
setattr(self, key, value)
self.quant_method = QuantizationMethod.AUTOROUND
self.post_init()
def post_init(self):
r"""Safety checker that arguments are correct."""
if self.bits not in [2, 3, 4, 8]:
raise ValueError(f"Only support quantization to [2,3,4,8] bits but found {self.bits}")
if self.group_size != -1 and self.group_size <= 0:
raise ValueError("group_size must be greater than 0 or equal to -1")
def get_loading_attributes(self):
loading_attibutes_dict = {"backend": self.backend}
return loading_attibutes_dict
def to_dict(self):
config_dict = super().to_dict()
return config_dict
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
quant_method = config_dict["quant_method"]
if "auto-round" not in quant_method and "gptq" not in quant_method and "awq" not in quant_method:
raise NotImplementedError(
"Failed to convert to auto_round format. Only `gptqv1`, `awq`, and `auto-round` formats are supported."
)
if "gptq" in quant_method and "meta" in config_dict:
raise NotImplementedError("Failed to convert gptq format to auto_round format. Only supports `gptqv1`")
if "awq" in quant_method and config_dict.get("version", "gemm") != "gemm":
raise NotImplementedError(
"Failed to convert awq format to auto_round format. Only supports awq format with gemm version"
)
if "auto-round" not in quant_method:
config_dict["packing_format"] = f"auto_round:{quant_method}"
return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
@dataclass
class HqqConfig(QuantizationConfigMixin):
"""

View File

View File

@ -0,0 +1,209 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
from transformers import AutoModelForCausalLM, AutoRoundConfig, AutoTokenizer
from transformers.testing_utils import (
require_accelerate,
require_auto_round,
require_intel_extension_for_pytorch,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_torch_available
if is_torch_available():
import torch
@slow
@require_torch_gpu
@require_auto_round
@require_accelerate
class AutoRoundTest(unittest.TestCase):
model_name = "OPEA/Qwen2.5-1.5B-Instruct-int4-sym-inc"
input_text = "There is a girl who likes adventure,"
EXPECTED_OUTPUTS = set()
## Different backends may produce slight variations in output
EXPECTED_OUTPUTS.add(
"There is a girl who likes adventure, and she has been exploring the world "
"for many years. She travels to different countries and cultures, trying new "
"things every day. One of her favorite places to visit is a small village in "
"the mountains where"
)
EXPECTED_OUTPUTS.add(
"There is a girl who likes adventure, and she has been exploring the world for many years. She has visited every country in Europe and has even traveled to some of the most remote parts of Africa. She enjoys hiking through the mountains and discovering"
)
device_map = "cuda"
# called only once for all test in this class
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
torch.cuda.synchronize()
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name, device_map=cls.device_map, torch_dtype=torch.float16
)
def tearDown(self):
gc.collect()
torch.cuda.empty_cache()
gc.collect()
def test_quantized_model(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = self.quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_raise_if_non_quantized(self):
model_id = "facebook/opt-125m"
quantization_config = AutoRoundConfig(bits=4)
with self.assertRaises(ValueError):
_ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
def test_quantized_model_bf16(self):
"""
Simple test that checks if the quantized model is working properly with bf16
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = AutoRoundConfig(backend="triton")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=self.device_map,
quantization_config=quantization_config,
)
output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_intel_extension_for_pytorch
def test_quantized_model_on_cpu(self):
"""
Simple test that checks if the quantized model is working properly
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt")
quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, torch_dtype="auto")
output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_save_pretrained(self):
"""
Simple test that checks if the quantized model is working properly after being saved and loaded
"""
## some backends like marlin/ipex will repack the weight that caused the weight shape changed
with tempfile.TemporaryDirectory() as tmpdirname:
quantization_config = AutoRoundConfig(backend="triton")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
device_map=self.device_map,
torch_dtype=torch.float16,
quantization_config=quantization_config,
)
quantized_model.save_pretrained(tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map="cuda")
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
@require_torch_multi_gpu
def test_quantized_model_multi_gpu(self):
"""
Simple test that checks if the quantized model is working properly with multiple GPUs
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = AutoRoundConfig(backend="triton")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, device_map="auto", quantization_config=quantization_config, torch_dtype="auto"
)
output = quantized_model.generate(**input_ids, max_new_tokens=40, do_sample=False)
self.assertIn(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_convert_from_gptq(self):
"""
Simple test that checks if auto-round work properly wth gptq format
"""
model_name = "ybelkada/opt-125m-gptq-4bit"
quantization_config = AutoRoundConfig()
model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="cuda", quantization_config=quantization_config, torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
@require_intel_extension_for_pytorch
def test_convert_from_awq_cpu(self):
"""
Simple test that checks if auto-round work properly wth awq format
"""
model_name = "casperhansen/opt-125m-awq"
quantization_config = AutoRoundConfig()
model = AutoModelForCausalLM.from_pretrained(
model_name, device_map="cpu", quantization_config=quantization_config, torch_dtype="auto"
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])
def test_mixed_bits(self):
"""
Simple test that checks if auto-round work properly wth mixed bits
"""
model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)
layer_config = {
"model.decoder.layers.0.self_attn.k_proj": {"bits": 8},
"model.decoder.layers.6.self_attn.out_proj": {"bits": 2, "group_size": 32},
}
bits, group_size, sym = 4, 128, True
from auto_round import AutoRound
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym, layer_config=layer_config)
with tempfile.TemporaryDirectory() as tmpdirname:
autoround.quantize_and_save(output_dir=tmpdirname)
model = AutoModelForCausalLM.from_pretrained(tmpdirname, torch_dtype=torch.float16, device_map="cuda")
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
tokenizer.decode(model.generate(**inputs, max_new_tokens=5)[0])