# torchao

torchao is a PyTorch architecture optimization library with support for custom high performance data types, quantization, and sparsity. It is composable with native PyTorch features such as torch.compile for even faster inference and training.

Install torchao with the following command.

```bash
# update 🤗 Transformers to the latest version, since the example below relies on the new automatic compilation
pip install --upgrade torch torchao transformers
```

torchao supports many quantization types for different data types (int4, float8, weight only, etc.). The Transformers integration currently supports int4 weight-only quantization, int8 weight-only quantization, and int8 dynamic quantization of weights.
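As a quick illustration, the sketch below builds a [TorchAoConfig] for each of these quantization types; the exact `quant_type` strings are assumed to match your installed torchao and Transformers versions.

```py
from transformers import TorchAoConfig

# int4 weight-only quantization with a group size of 128
int4_config = TorchAoConfig("int4_weight_only", group_size=128)

# int8 weight-only quantization
int8_wo_config = TorchAoConfig("int8_weight_only")

# int8 dynamic quantization of activations with int8 weights
int8_dyn_config = TorchAoConfig("int8_dynamic_activation_int8_weight")
```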

You can manually choose the quantization type and settings, or let torchao automatically select a quantization type for you with autoquant.

Create a [TorchAoConfig] and specify the quantization type and group_size of the weights to quantize. Set the cache_implementation to "static" to automatically torch.compile the forward method.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to speed up inference
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
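The tip above describes a CPU-only path. A minimal sketch of that variant is shown below, assuming torchao 0.8.0+ with `Int4CPULayout` importable from `torchao.dtypes`.

```py
import torch
from torchao.dtypes import Int4CPULayout  # assumes torchao 0.8.0+
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

# same int4 weight-only config, but with the CPU layout and everything placed on the CPU
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cpu")

output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```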

Run the code below to benchmark the quantized model's performance.

```py
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", device_map="auto", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```

The autoquant API automatically chooses a quantization type for quantizable layers (nn.Linear) by micro-benchmarking on input type and shape and compiling a single linear layer.

Create a [TorchAoConfig] and set the quantization type to "autoquant". Set the cache_implementation to "static" to automatically torch.compile the forward method. Finally, call finalize_autoquant on the quantized model to finalize the quantization and log the input shapes.


```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to speed up inference
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# explicitly call `finalize_autoquant` (may be refactored and removed in the future)
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Run the code below to benchmark the quantized model's performance.

```py
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
print("autoquantized model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", device_map="auto", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```

## Serialization

torchao implements torch.Tensor subclasses for maximum flexibility in supporting new quantized torch.Tensor formats. Safetensors serialization and deserialization do not work with torchao.

To avoid arbitrary user code execution, torchao sets weights_only=True in torch.load to ensure only tensors are loaded. Any known user functions can be whitelisted with add_safe_globals.

```py
# don't serialize the model with Safetensors
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
```
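Reloading the checkpoint afterwards goes through the usual from_pretrained path. The sketch below is one way to do it, assuming the directory saved above; if loading fails because a torchao class is not in the weights_only allowlist, it can be registered first with torch.serialization.add_safe_globals.

```py
import torch
from transformers import AutoModelForCausalLM

output_dir = "llama3-8b-int4wo-128"  # directory saved above with safe_serialization=False

# if needed, allowlist a torchao class for weights_only loading, e.g.:
# torch.serialization.add_safe_globals([SomeTorchaoTensorClass])  # hypothetical class name

reloaded_model = AutoModelForCausalLM.from_pretrained(
    output_dir,
    device_map="auto",
    torch_dtype="auto",
)
```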

## Resources

For a better sense of expected performance, view the benchmarks for various models with CUDA and XPU backends.

Refer to Other Available Quantization Techniques for more examples and documentation.