HIGGS

HIGGS is a zero-shot quantization algorithm that combines Hadamard preprocessing with MSE-Optimal quantization grids to achieve lower quantization error and state-of-the-art performance.
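The two ingredients named above can be illustrated with a toy, pure-Python sketch. This is not the HIGGS implementation: the function names are hypothetical, the grid is an assumed 4-level scalar grid (roughly the Lloyd-Max levels for a standard Gaussian), and the random sign flips and vector grids a production scheme would likely use are left out. It only shows the general idea of rotating weights with an orthonormal Hadamard transform, snapping each value to the nearest grid point, and rotating back.

```python
import math

def hadamard_transform(x):
    """Orthonormal fast Walsh-Hadamard transform; len(x) must be a power of two."""
    n = len(x)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    out = list(x)
    h = 1
    while h < n:
        # Butterfly step: combine pairs of entries that are h apart.
        for i in range(0, n, 2 * h):
            for j in range(i, i + h):
                a, b = out[j], out[j + h]
                out[j], out[j + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform orthonormal
    return [v * scale for v in out]

# Assumed illustrative grid, not taken from the HIGGS paper or code.
GRID = [-1.510, -0.4528, 0.4528, 1.510]

def quantize(x, grid=GRID):
    """Scale x to unit RMS, then map each value to the index of the nearest grid point."""
    rms = math.sqrt(sum(v * v for v in x) / len(x)) or 1.0
    codes = [min(range(len(grid)), key=lambda k: (v / rms - grid[k]) ** 2) for v in x]
    return codes, rms

def dequantize(codes, rms, grid=GRID):
    """Map grid indices back to real values."""
    return [grid[c] * rms for c in codes]

weights = [4.0, 0.1, -0.2, 0.05, 0.0, -0.1, 0.15, 0.1]  # one large outlier
rotated = hadamard_transform(weights)       # outlier energy is spread over all entries
codes, scale = quantize(rotated)            # 2-bit codes plus one scale
restored = hadamard_transform(dequantize(codes, scale))  # the orthonormal transform is its own inverse
```

Because the transform is orthonormal, the squared error introduced on the rotated values equals the squared error on the restored weights, which is why the grid can be chosen to minimize MSE in the rotated domain.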

Runtime support for HIGGS is implemented through the FLUTE library. Only the 8B, 70B, and 405B variants of Llama 3.1 and Llama 3.0, and the 9B and 27B variants of Gemma 2 are currently supported. HIGGS does not currently support quantized training or backward passes in general.

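Run one of the commands below to install FLUTE. The first installs the default build from PyPI; the second installs from FLUTE's wheel index for CUDA 12.4.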

pip install flute-kernel
pip install flute-kernel -i https://flute-ai.github.io/whl/cu12.4
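To confirm the install before loading a model, a quick availability check can help. The sketch below assumes the `flute-kernel` package exposes a top-level module named `flute`; that module name is an assumption and is not stated in this document. The helper `is_module_available` is hypothetical and uses only the standard library.

```python
import importlib.util

def is_module_available(module_name: str) -> bool:
    """Return True if a top-level module can be found without importing it."""
    return importlib.util.find_spec(module_name) is not None

# Assumed import name for the flute-kernel package; adjust if it differs.
FLUTE_MODULE = "flute"

if is_module_available(FLUTE_MODULE):
    print(f"{FLUTE_MODULE} is installed")
else:
    print(f"{FLUTE_MODULE} not found; install flute-kernel first")
```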

Create a HiggsConfig with the number of bits to quantize a model to.

from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)
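Once loaded, the quantized model can be used for inference like any other causal language model. The following is a minimal sketch, not an official recipe: `generate_with_higgs` is a hypothetical helper name, and the default of 32 new tokens is an arbitrary choice. Imports are deferred into the function so that defining it does not require a GPU or FLUTE.

```python
def generate_with_higgs(model_id, prompt, bits=4, max_new_tokens=32):
    """Load a model with HIGGS quantization and generate a completion for `prompt`."""
    # Deferred imports: only needed when the helper is actually called.
    from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        quantization_config=HiggsConfig(bits=bits),
        device_map="auto",
    )

    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
```

For example, `generate_with_higgs("google/gemma-2-9b-it", "Hello, my name is")` would quantize the model to 4 bits on load and return the decoded text.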

Tip

Find models pre-quantized with HIGGS in the official ISTA-DASLab collection.

torch.compile

HIGGS is fully compatible with torch.compile.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)

model = torch.compile(model)

Refer to the table below for a benchmark of forward passes per second for Llama-3.1-8B-Instruct on an RTX 4090.

| Batch Size | BF16 (with torch.compile) | HIGGS 4bit (without torch.compile) | HIGGS 4bit (with torch.compile) |
|---|---|---|---|
| 1 | 59 | 41 | 124 |
| 4 | 57 | 42 | 123 |
| 16 | 56 | 41 | 120 |
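A throughput figure like forward passes per second can be measured with a small timing helper. The sketch below is an assumption about how such a measurement could be set up, not the harness used to produce the table above; `forward_passes_per_second` and its parameters are hypothetical names. Warmup iterations are excluded from timing because the first calls to a compiled model include compilation time.

```python
import time

def forward_passes_per_second(forward, warmup=3, iters=20, sync=None):
    """Return how many calls to `forward()` complete per second.

    `sync` is an optional callable that blocks until queued device work
    has finished; it is called before starting and stopping the timer.
    """
    # Warmup runs are not timed (they may include compilation or caching).
    for _ in range(warmup):
        forward()
    if sync is not None:
        sync()

    start = time.perf_counter()
    for _ in range(iters):
        forward()
    if sync is not None:
        sync()
    elapsed = time.perf_counter() - start
    return iters / elapsed
```

With PyTorch on a GPU, `forward` could be a lambda that runs the model on a fixed batch under `torch.no_grad()`, and `sync` could be `torch.cuda.synchronize`, since CUDA kernels run asynchronously and the timer would otherwise stop before the work has finished.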