transformers/docs/source/en/quantization/overview.md
jiqing-feng 387663e571
Enable gptqmodel (#35012)
* gptqmodel

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update readme

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* gptqmodel need use checkpoint_format (#1)

* gptqmodel need use checkpoint_format

* fix quantize

* Update quantization_config.py

* Update quantization_config.py

* Update quantization_config.py

---------

Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>

* Revert quantizer_gptq.py (#2)

* revert quantizer_gptq.py change

* pass **kwargs

* limit gptqmodel and optimum version

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix warning

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix version check

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* revert unrelated changes

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* enable gptqmodel tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix requires gptq

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Fix Transformer compat (#3)

* revert quantizer_gptq.py change

* pass **kwargs

* add meta info

* cleanup

* cleanup

* Update quantization_config.py

* hf_select_quant_linear pass checkpoint_format and meta

* fix GPTQTestCUDA

* Update test_gptq.py

* gptqmodel.hf_select_quant_linear() now does not select ExllamaV2

* cleanup

* add backend

* cleanup

* cleanup

* no need check exllama version

* Update quantization_config.py

* lower checkpoint_format and backend

* check none

* cleanup

* Update quantization_config.py

* fix self.use_exllama == False

* spell

* fix unittest

* fix unittest

---------

Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>

* fix format

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix format again

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update gptqmodel version (#6)

* update gptqmodel version

* update gptqmodel version

* fix unit test (#5)

* update gptqmodel version

* update gptqmodel version

* "not self.use_exllama" is not equivalent to "self.use_exllama==False"

* fix unittest

* update gptqmodel version

* backend is loading_attibutes (#7)

* fix format and tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix memory check

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix device mismatch

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* fix result check

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* Update src/transformers/quantizers/quantizer_gptq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_gptq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/quantizers/quantizer_gptq.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* update tests

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* review: update docs (#10)

* review: update docs (#12)

* review: update docs

* fix typo

* update tests for gptqmodel

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>

* update document (#9)

* update overview.md

* cleanup

* Update overview.md

* Update overview.md

* Update overview.md

* update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

* Update gptq.md

---------

Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>

* typo

* doc note for asymmetric quant

* typo with apple silicon(e)

* typo for marlin

* column name revert: review

* doc rocm support

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/gptq.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/overview.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

* Update docs/source/en/quantization/overview.md

Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>

---------

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: LRL-ModelCloud <165116337+LRL-ModelCloud@users.noreply.github.com>
Co-authored-by: ZX-ModelCloud <zx@modelcloud.ai>
Co-authored-by: Qubitium-ModelCloud <qubitium@modelcloud.ai>
Co-authored-by: ZX-ModelCloud <165115237+ZX-ModelCloud@users.noreply.github.com>
Co-authored-by: LRL <lrl@lbx.dev>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
2025-01-15 14:22:49 +01:00

95 lines
9.0 KiB
Markdown

<!--Copyright 2023 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Quantization
Quantization techniques focus on representing data with less information while also trying to not lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size which makes it easier to store and reduces memory-usage. Lower precision can also speedup inference because it takes less time to perform calculations with fewer bits.
<Tip>
Interested in adding a new quantization method to Transformers? Read the [HfQuantizer](./contribute) guide to learn how!
</Tip>
<Tip>
If you are new to the quantization field, we recommend you to check out these beginner-friendly courses about quantization in collaboration with DeepLearning.AI:
* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)
</Tip>
## When to use what?
The community has developed many quantization methods for various use cases. With Transformers, you can run any of these integrated methods depending on your use case because each method has their own pros and cons.
For example, some quantization methods require calibrating the model with a dataset for more accurate and "extreme" compression (up to 1-2 bits quantization), while other methods work out of the box with on-the-fly quantization.
Another parameter to consider is compatibility with your target device. Do you want to quantize on a CPU, GPU, or Apple silicon?
In short, supporting a wide range of quantization methods allows you to pick the best quantization method for your specific use case.
Use the table below to help you decide which quantization method to use.
| Quantization Method | On the fly quantization | CPU | CUDA GPU | ROCm GPU | Metal (Apple Silicon) | Intel GPU | Torch compile() | Bits | PEFT Fine Tuning | Serializable with 🤗Transformers | 🤗Transformers Support | Link to library |
|-----------------------------------------------|----------------------|-----------------|----------|-----------|------------------------------------|-----------------|-----------------|---------------|------------------|-----------------------------|-------------------------|---------------------------------------------|
| [AQLM](./aqlm.md) | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/2 | 🟢 | 🟢 | 🟢 | https://github.com/Vahe1994/AQLM |
| [AWQ](./awq.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | ? | 4 | 🟢 | 🟢 | 🟢 | https://github.com/casper-hansen/AutoAWQ |
| [bitsandbytes](./bitsandbytes.md) | 🟢 | 🟡 <sub>1</sub> | 🟢 | 🟡 <sub>1</sub> | 🔴 <sub>2</sub> | 🟡 <sub>1</sub> | 🔴 <sub>1</sub> | 4/8 | 🟢 | 🟢 | 🟢 | https://github.com/bitsandbytes-foundation/bitsandbytes |
| [compressed-tensors](./compressed_tensors.md) | 🔴 | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 1/8 | 🟢 | 🟢 | 🟢 | https://github.com/neuralmagic/compressed-tensors |
| [EETQ](./eetq.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | ? | 8 | 🟢 | 🟢 | 🟢 | https://github.com/NetEase-FuXi/EETQ |
| [GGUF / GGML (llama.cpp)](../gguf.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 1/8 | 🔴 | [See Notes](../gguf.md) | [See Notes](../gguf.md) | https://github.com/ggerganov/llama.cpp |
| [GPTQModel](./gptq.md) | 🔴 | 🟢 <sub>3</sub> | 🟢 | 🟢 | 🟢 | 🟢 <sub>4</sub> | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/ModelCloud/GPTQModel |
| [AutoGPTQ](./gptq.md) | 🔴 | 🔴 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 2/3/4/8 | 🟢 | 🟢 | 🟢 | https://github.com/AutoGPTQ/AutoGPTQ |
| [HIGGS](./higgs.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 2/4 | 🔴 | 🟢 | 🟢 | https://github.com/HanGuo97/flute |
| [HQQ](./hqq.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🔴 | 🔴 | 🟢 | 1/8 | 🟢 | 🔴 | 🟢 | https://github.com/mobiusml/hqq/ |
| [optimum-quanto](./quanto.md) | 🟢 | 🟢 | 🟢 | 🔴 | 🟢 | 🔴 | 🟢 | 2/4/8 | 🔴 | 🔴 | 🟢 | https://github.com/huggingface/optimum-quanto |
| [FBGEMM_FP8](./fbgemm_fp8.md) | 🟢 | 🔴 | 🟢 | 🔴 | 🔴 | 🔴 | 🔴 | 8 | 🔴 | 🟢 | 🟢 | https://github.com/pytorch/FBGEMM |
| [torchao](./torchao.md) | 🟢 | | 🟢 | 🔴 | 🟡 <sub>5</sub> | 🔴 | | 4/8 | | 🟢🔴 | 🟢 | https://github.com/pytorch/ao |
| [VPTQ](./vptq.md) | 🔴 | 🔴 | 🟢 | 🟡 | 🔴 | 🔴 | 🟢 | 1/8 | 🔴 | 🟢 | 🟢 | https://github.com/microsoft/VPTQ |
<Tip>
**1:** bitsandbytes is being refactored to support multiple backends beyond CUDA. Currently, ROCm (AMD GPU) and Intel CPU implementations are mature, with Intel XPU in progress and Apple Silicon support expected by Q4/Q1. For installation instructions and the latest backend updates, visit [this link](https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend). Check out [these docs](https://huggingface.co/docs/bitsandbytes/main/en/non_cuda_backends) for more details and feedback links.
</Tip>
<Tip>
**2:** bitsandbytes is seeking contributors to help develop and lead the Apple Silicon backend. Interested? Contact them directly via their repo. Stipends may be available through sponsorships.
</Tip>
<Tip>
**3:** GPTQModel[CPU] supports 4-bit via IPEX on Intel/AMD and full bit range via Torch on Intel/AMD/Apple Silicon.
</Tip>
<Tip>
**4:** GPTQModel[Intel GPU] via IPEX only supports 4-bit for Intel Datacenter Max/Arc GPUs.
</Tip>
<Tip>
**5:** torchao only supports int4 weight on Metal (Apple Silicon).
</Tip>