
# HIGGS

HIGGS is a zero-shot quantization algorithm that combines Hadamard preprocessing with MSE-optimal quantization grids to achieve lower quantization error and state-of-the-art performance. You can find more information in the paper [arxiv.org/abs/2411.17525](https://arxiv.org/abs/2411.17525).

Runtime support for HIGGS is implemented through the [FLUTE](https://github.com/HanGuo97/flute) kernel library.
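
To build intuition for the algorithm's two ingredients, here is a minimal, self-contained sketch: it rotates a weight matrix with a random orthogonal matrix (a stand-in for the faster randomized Hadamard transform HIGGS uses) and then snaps each value to the nearest point of a fixed grid (a stand-in for the MSE-optimal grids). This is purely illustrative and unrelated to the optimized FLUTE kernels:

```python
import torch

def nearest_grid(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    """Snap every element of x to its closest value in a 1-D grid."""
    idx = (x.reshape(-1, 1) - grid.reshape(1, -1)).abs().argmin(dim=1)
    return grid[idx].reshape(x.shape)

w = torch.randn(64, 64)  # toy weight matrix

# Stand-in for the Hadamard preprocessing: a random orthogonal rotation
# that decorrelates the weights and makes them approximately Gaussian.
rot, _ = torch.linalg.qr(torch.randn(64, 64))
w_rot = w @ rot

# Stand-in for an MSE-optimal grid: 16 uniform levels (~4 bits).
grid = torch.linspace(w_rot.min().item(), w_rot.max().item(), 16)

w_quant = nearest_grid(w_rot, grid)
w_restored = w_quant @ rot.T  # the rotation is orthogonal, so rot.T inverts it

print("quantization MSE:", (w - w_restored).pow(2).mean().item())
```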

## Quantization Example

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, HiggsConfig

# Quantize the model to 4-bit HIGGS on the fly while loading.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-9b-it",
    quantization_config=HiggsConfig(bits=4),
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b-it")

tokenizer.decode(model.generate(
    **tokenizer("Hi,", return_tensors="pt").to(model.device),
    temperature=0.5,
    top_p=0.80,
)[0])
```
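
`HiggsConfig` exposes more than the bit width. As a hedged example (argument names are assumptions based on the config's documented options; check the `HiggsConfig` docstring in your installed version), a 2-bit setup that keeps the LM head in full precision might look like:

```python
from transformers import HiggsConfig

# 2-bit HIGGS; `p` selects the grid dimensionality and `modules_to_not_convert`
# lists modules to keep in full precision. Both arguments are assumptions here:
# verify them against the HiggsConfig docstring for your transformers version.
quantization_config = HiggsConfig(
    bits=2,
    p=2,
    modules_to_not_convert=["lm_head"],
)
```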

## Pre-quantized models

Some pre-quantized models can be found in the official HIGGS collection on the Hugging Face Hub.
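
Loading a pre-quantized checkpoint requires no `HiggsConfig`, since the quantization settings are stored in the model's config. The repository id below is a placeholder; substitute a real checkpoint from the collection:

```python
from transformers import AutoModelForCausalLM

# Placeholder repository id: pick an actual checkpoint from the HIGGS collection.
model = AutoModelForCausalLM.from_pretrained(
    "<org>/<model>-HIGGS-4bit",
    device_map="auto",
)
```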

## Current Limitations

### Architectures

Currently, FLUTE, and HIGGS by extension, only supports Llama 3.1 and 3.0 models with 8B, 70B, and 405B parameters, as well as Gemma-2 9B and 27B. We're working on supporting a more diverse set of models, and on allowing arbitrary models by modifying the FLUTE compilation procedure.

### torch.compile

HIGGS is fully compatible with `torch.compile`. Compiling `model.forward`, as described in the `torch.compile` documentation, yields the following speedups on an RTX 4090 for Llama-3.1-8B-Instruct (forward passes per second):

| Batch Size | BF16 (with `torch.compile`) | HIGGS 4bit (no `torch.compile`) | HIGGS 4bit (with `torch.compile`) |
|-----------:|----------------------------:|--------------------------------:|----------------------------------:|
| 1          | 59                          | 41                              | 124                                |
| 4          | 57                          | 42                              | 123                                |
| 16         | 56                          | 41                              | 120                                |
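
Enabling this is standard `torch.compile` usage rather than a HIGGS-specific API; a minimal sketch:

```python
import torch

# Compile the forward pass; the first few calls are slow while kernels compile,
# after which generation runs through the compiled graph.
model.forward = torch.compile(model.forward)
```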

### Quantized training

Currently, HIGGS doesn't support quantized training (and backward passes in general). We're working on adding support for it.