# torchao

[torchao](https://github.com/pytorch/ao) is a PyTorch architecture optimization library with support for custom high performance data types, quantization, and sparsity. It is composable with native PyTorch features such as [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) for even faster inference and training.

Install torchao with the following command.

```bash
# update 🤗 Transformers to the latest version, as the example script below uses the new auto compilation
pip install --upgrade torch torchao transformers
```

torchao supports many quantization types for different data types (int4, float8, weight only, etc.). The Transformers integration currently supports int4 weight-only quantization, int8 weight-only quantization, int8 dynamic quantization, and automatic selection with autoquant. You can manually choose the quantization type and settings or let torchao select them automatically.

Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+. A CPU sketch is included after the benchmark below.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Run the code below to benchmark the quantized model's performance.

```py
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", device_map="auto", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```
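As noted in the tip above, the same int4 weight-only configuration can run on a CPU with torchao 0.8.0+. The example below is a minimal sketch, not an exhaustive recipe: it assumes `Int4CPULayout` is importable from `torchao.dtypes` in your installed torchao version and that the `layout` keyword is forwarded by [`TorchAoConfig`] to the underlying torchao quantization function.

```py
import torch
from torchao.dtypes import Int4CPULayout  # assumes torchao 0.8.0+
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

# same int4 weight-only settings as above, but with a CPU layout and device_map="cpu"
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cpu")

output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```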
The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on the input type and shape and compiling a single linear layer.

Create a [`TorchAoConfig`] and set the quantization type to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# explicitly call `finalize_autoquant` (may be refactored and removed in the future)
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```

Run the code below to benchmark the quantized model's performance.

```py
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
print("autoquantized model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", device_map="auto", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```

## Serialization

torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.

To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).

```py
# don't serialize model with Safetensors
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained(output_dir, safe_serialization=False)
```

## Resources

For a better sense of expected performance, view the [benchmarks](https://github.com/pytorch/ao/tree/main/torchao/quantization#benchmarks) for various models with CUDA and XPU backends.

Refer to [Other Available Quantization Techniques](https://github.com/pytorch/ao/tree/main/torchao/quantization#other-available-quantization-techniques) for more examples and documentation.
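To reload the checkpoint saved in the [Serialization](#serialization) section, pass the output directory (or a Hub repository id it was pushed to) back to [`~PreTrainedModel.from_pretrained`]. The sketch below is a minimal, hedged example: it assumes the `llama3-8b-int4wo-128` directory created above and a torchao installation matching the one used to save the model.

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

# directory produced by save_pretrained(..., safe_serialization=False) above
ckpt_id = "llama3-8b-int4wo-128"
loaded_quantized_model = AutoModelForCausalLM.from_pretrained(ckpt_id, device_map="auto")

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to(loaded_quantized_model.device)

output = loaded_quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```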