Restructure torchao quantization examples (#37592)

* Restructure torchao quantization examples

Summary:
Mainly restructured the examples by hardware and then listed
the recommended quantization methods for each hardware type: H100 GPU, A100 GPU, and CPU.

Also added an example for push_to_hub.

Test Plan:
not required

Reviewers:

Subscribers:

Tasks:

Tags:

* update

* drop float8 cpu

* address comments and simplify

* small update

* link update

* minor update
Jerry Zhang, 2025-04-22 02:20:34 -07:00, committed by GitHub
commit 7eb1107cc2 (parent 006530d285)


@@ -33,10 +33,11 @@ See the table below for additional torchao features.
torchao supports the [quantization techniques](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) below.
- A16W8 Float8 Dynamic Quantization
- A16W8 Float8 Weight Only Quantization
- A8W8 Int8 Dynamic Quantization
- A16W8 Int8 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization
- Autoquantization
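For reference, here is a hedged sketch (not part of the original page) of how the techniques above map to torchao config classes; the class names assume a recent torchao release.

```py
# A rough mapping from the techniques listed above to torchao config classes
# (see the hardware-specific examples below for full usage).
from transformers import TorchAoConfig
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,  # Float8 dynamic quantization
    Float8WeightOnlyConfig,                     # Float8 weight only quantization
    Int8DynamicActivationInt8WeightConfig,      # A8W8 int8 dynamic quantization
    Int8WeightOnlyConfig,                       # A16W8 int8 weight only quantization
    Int4WeightOnlyConfig,                       # A16W4 int4 weight only quantization
)

# each config is wrapped in TorchAoConfig before being passed to from_pretrained
quantization_config = TorchAoConfig(quant_type=Int8WeightOnlyConfig(group_size=128))
```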
@@ -44,7 +45,7 @@ Check the table below to see if your hardware is compatible.
| Component | Compatibility |
|----------|----------------|
| CUDA Versions | ✅ cu118, cu126, cu128 |
| CPU | ✅ change `device_map="cpu"` (see examples below) |
@@ -56,14 +57,14 @@ Install torchao from PyPI or the PyTorch index with the following commands.
```bash
# Update 🤗 Transformers to the latest version, since the example script below uses the new auto compilation
# Stable release from PyPI, which defaults to CUDA 12.6
pip install --upgrade torchao transformers
```
</hfoption>
<hfoption id="PyTorch Index">
Stable Release from the PyTorch index
```bash
pip install torchao --index-url https://download.pytorch.org/whl/cu126 # options are cpu/cu118/cu126/cu128
```
</hfoption>
</hfoptions>
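As an optional sanity check (a sketch, not from the original instructions), you can confirm that the packages installed above are visible to Python:

```py
# print the installed versions of torchao and transformers
import torchao
import transformers

print("torchao:", torchao.__version__)
print("transformers:", transformers.__version__)
```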
@@ -80,15 +81,19 @@ You can manually choose the quantization types and settings or automatically sel
Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize (for int8 weight only and int4 weight only). Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.
The following examples show the recommended quantization methods for each type of hardware: H100 GPU, A100 GPU, and CPU.
### H100 GPU
<hfoptions id="examples-H100-GPU">
<hfoption id="float8-dynamic-and-weight-only">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig, Float8WeightOnlyConfig

quant_config = Float8DynamicActivationFloat8WeightConfig()
# or float8 weight only quantization
# quant_config = Float8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
@@ -109,14 +114,114 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
<hfoption id="int4-weight-only">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import GemliteUIntXWeightOnlyConfig
# We integrated with gemlite, which optimizes for batch size N on A100 and H100
quant_config = GemliteUIntXWeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### A100 GPU
<hfoptions id="examples-A100-GPU">
<hfoption id="int8-dynamic-and-weight-only">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
<hfoption id="int4-weight-only">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import GemliteUIntXWeightOnlyConfig, Int4WeightOnlyConfig

# For batch size N, we recommend gemlite, which may require autotuning
# the default is 4-bit; 8-bit is also supported by passing `bit_width=8`
quant_config = GemliteUIntXWeightOnlyConfig(group_size=128)
# For batch size 1, there is also a custom tinygemm kernel that is optimized specifically for this case
# set `use_hqq=True` for better accuracy
# quant_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.1-8B-Instruct",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### CPU
<hfoptions id="examples-CPU">
<hfoption id="int8-dynamic-and-weight-only">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int8DynamicActivationInt8WeightConfig, Int8WeightOnlyConfig

quant_config = Int8DynamicActivationInt8WeightConfig()
# or int8 weight only quantization
# quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
@@ -136,35 +241,7 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
<hfoption id="int4-weight-only">
> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`.
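The body of this CPU example is elided by the diff here. As a minimal sketch of the configuration the tip describes, assuming `Int4CPULayout` can be imported from `torchao.dtypes` in your torchao release:

```py
from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import Int4CPULayout  # assumed import path for recent torchao releases

quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)
# pass `quantization_config` to AutoModelForCausalLM.from_pretrained(..., device_map="cpu", ...)
# exactly as in the CPU int8 example above
```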
@@ -195,116 +272,6 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### Autoquant
@@ -313,6 +280,8 @@ If you want to automatically choose a quantization type for quantizable layers (
The `autoquant` API automatically chooses a quantization type by micro-benchmarking on input type and shape and compiling a single linear layer.
Note: autoquant is for GPU only right now.
Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.
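The full example is elided by the diff at this point, so here is a minimal sketch of the flow just described, reusing the model from the examples above; treat it as an approximation rather than the exact snippet from the rendered page.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B-Instruct",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# generate once so autoquant can micro-benchmark the real input shapes
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# finalize the quantization choices made by autoquant
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```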
@@ -346,11 +315,24 @@ torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/not
To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
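For example, a small illustrative sketch of the allowlisting step, where `my_custom_fn` is a hypothetical user function and not something defined in this document:

```py
import torch

def my_custom_fn():
    """Hypothetical user function referenced by a saved checkpoint."""
    pass

# allowlist it so torch.load(..., weights_only=True) will accept it when deserializing
torch.serialization.add_safe_globals([my_custom_fn])
```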
<hfoptions id="serialization-examples">
<hfoption id="save-locally">
```py
# don't serialize model with Safetensors
output_dir = "llama3-8b-int4wo-128"
quantized_model.save_pretrained("llama3-8b-int4wo-128", safe_serialization=False)
```
</hfoption>
<hfoption id="push-to-huggingface-hub">
```py
# don't serialize model with Safetensors
USER_ID = "your_huggingface_user_id"
REPO_ID = "llama3-8b-int4wo-128"
quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serialization=False)
tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
```
</hfoption>
## Loading quantized models
@@ -486,4 +468,4 @@ Refer to [Other Available Quantization Techniques](https://github.com/pytorch/ao
## Issues
If you encounter any issues with the Transformers integration, please open an issue on the [Transformers](https://github.com/huggingface/transformers/issues) repository. For issues directly related to torchao, please open an issue on the [torchao](https://github.com/pytorch/ao/issues) repository.