mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Restructure torchao quantization examples (#37592)
* Restructure torchao quantization examples Summary: Mainly structured the examples by hardwares and then listed the recommended quantization methods for each hardware H100 GPU, A100 GPU and CPU Also added example for push_to_hub Test Plan: not required Reviewers: Subscribers: Tasks: Tags: * update * drop float8 cpu * address comments and simplify * small update * link update * minor update
This commit is contained in:
parent
006530d285
commit
7eb1107cc2
@ -33,10 +33,11 @@ See the table below for additional torchao features.
|
||||
|
||||
torchao supports the [quantization techniques](https://github.com/pytorch/ao/blob/main/torchao/quantization/README.md) below.
|
||||
|
||||
- A16W8 Int8 WeightOnly Quantization
|
||||
- A16W4 WeightOnly Quantization
|
||||
- A8W8 Int8 Dynamic Quantization
|
||||
- A16W8 Float8 Dynamic Quantization
|
||||
- A16W8 Float8 WeightOnly Quantization
|
||||
- A8W8 Int8 Dynamic Quantization
|
||||
- A16W8 Int8 Weight Only Quantization
|
||||
- A16W4 Int4 Weight Only Quantization
|
||||
- Autoquantization
|
||||
|
||||
|
||||
@ -44,7 +45,7 @@ Check the table below to see if your hardware is compatible.
|
||||
|
||||
| Component | Compatibility |
|
||||
|----------|----------------|
|
||||
| CUDA Versions | ✅ cu118, cu124, cu126, cu128 |
|
||||
| CUDA Versions | ✅ cu118, cu126, cu128 |
|
||||
| CPU | ✅ change `device_map="cpu"` (see examples below) |
|
||||
|
||||
|
||||
@ -56,14 +57,14 @@ Install torchao from PyPi or the PyTorch index with the following commands.
|
||||
|
||||
```bash
|
||||
# Updating 🤗 Transformers to the latest version, as the example script below uses the new auto compilation
|
||||
# Stable release from Pypi which will default to CUDA 12.4
|
||||
# Stable release from Pypi which will default to CUDA 12.6
|
||||
pip install --upgrade torchao transformers
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="PyTorch Index">
|
||||
Stable Release from the PyTorch index
|
||||
```bash
|
||||
pip install torchao --extra-index-url https://download.pytorch.org/whl/cu124 # options are cpu/cu118/cu124/cu126
|
||||
pip install torchao --index-url https://download.pytorch.org/whl/cu126 # options are cpu/cu118/cu126/cu128
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
@ -80,15 +81,19 @@ You can manually choose the quantization types and settings or automatically sel
|
||||
|
||||
Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize (for int8 weight only and int4 weight only). Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.
|
||||
|
||||
<hfoptions id="examples">
|
||||
<hfoption id="int8-weight-only cuda">
|
||||
We'll show examples for recommended quantization methods based on hardwares, e.g. A100 GPU, H100 GPU, CPU.
|
||||
|
||||
### H100 GPU
|
||||
<hfoptions id="examples-H100-GPU">
|
||||
<hfoption id="float8-dynamic-and-weight-only">
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
|
||||
|
||||
quant_config = Int8WeightOnlyConfig(group_size=128)
|
||||
quant_config = Float8DynamicActivationFloat8WeightConfig()
|
||||
# or float8 weight only quantization
|
||||
# quant_config = Float8WeightOnlyConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
@ -109,14 +114,114 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="int8-weight-only cpu">
|
||||
</hfoption>
|
||||
<hfoption id="int4-weight-only">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import GemliteUIntXWeightOnlyConfig
|
||||
|
||||
# We integrated with gemlite, which optimizes for batch size N on A100 and H100
|
||||
quant_config = GemliteUIntXWeightOnlyConfig(group_size=128)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### A100 GPU
|
||||
<hfoptions id="examples-A100-GPU">
|
||||
<hfoption id="int8-dynamic-and-weight-only">
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quant_config = Int8WeightOnlyConfig(group_size=128)
|
||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||
# or int8 weight only quantization
|
||||
# quant_config = Int8WeightOnlyConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="int4-weight-only">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
# For batch size N, we recommend gemlite, which may require autotuning
|
||||
# default is 4 bit, 8 bit is also supported by passing `bit_width=8`
|
||||
quant_config = GemliteUIntXWeightOnlyConfig(group_size=128)
|
||||
|
||||
# For batch size 1, we also have custom tinygemm kernel that's only optimized for this
|
||||
# We can set `use_hqq` to `True` for better accuracy
|
||||
# quant_config = Int4WeightOnlyConfig(group_size=128, use_hqq=True)
|
||||
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### CPU
|
||||
<hfoptions id="examples-CPU">
|
||||
<hfoption id="int8-dynamic-and-weight-only">
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int8WeightOnlyConfig
|
||||
|
||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||
# quant_config = Int8WeightOnlyConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
@ -136,35 +241,7 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="int4-weight-only cuda">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int4WeightOnlyConfig
|
||||
|
||||
quant_config = Int4WeightOnlyConfig(group_size=128)
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
<hfoption id="int4-weight-only cpu">
|
||||
<hfoption id="int4-weight-only">
|
||||
|
||||
> [!TIP]
|
||||
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`.
|
||||
@ -195,116 +272,6 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="int8-dynamic-quantization cuda">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
|
||||
|
||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="int8-dynamic-quantization cpu">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Int8DynamicActivationInt8WeightConfig
|
||||
|
||||
quant_config = Int8DynamicActivationInt8WeightConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="cpu",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="float8-weight-only cuda">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Float8WeightOnlyConfig
|
||||
|
||||
quant_config = Float8WeightOnlyConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="float8-weight-only cpu">
|
||||
|
||||
```py
|
||||
import torch
|
||||
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
|
||||
from torchao.quantization import Float8WeightOnlyConfig
|
||||
|
||||
quant_config = Float8WeightOnlyConfig()
|
||||
quantization_config = TorchAoConfig(quant_type=quant_config)
|
||||
|
||||
# Load and quantize the model
|
||||
quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
"meta-llama/Llama-3.1-8B-Instruct",
|
||||
torch_dtype="auto",
|
||||
device_map="cpu",
|
||||
quantization_config=quantization_config
|
||||
)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
|
||||
input_text = "What are we having for dinner?"
|
||||
input_ids = tokenizer(input_text, return_tensors="pt")
|
||||
|
||||
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
|
||||
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
|
||||
print(tokenizer.decode(output[0], skip_special_tokens=True))
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
</hfoptions>
|
||||
|
||||
### Autoquant
|
||||
@ -313,6 +280,8 @@ If you want to automatically choose a quantization type for quantizable layers (
|
||||
|
||||
The `autoquant` API automatically chooses a quantization type by micro-benchmarking on input type and shape and compiling a single linear layer.
|
||||
|
||||
Note: autoquant is for GPU only right now.
|
||||
|
||||
Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.
|
||||
|
||||
|
||||
@ -346,11 +315,24 @@ torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/not
|
||||
|
||||
To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
|
||||
|
||||
<hfoptions id="serialization-examples">
|
||||
<hfoption id="save-locally">
|
||||
```py
|
||||
# don't serialize model with Safetensors
|
||||
output_dir = "llama3-8b-int4wo-128"
|
||||
quantized_model.save_pretrained("llama3-8b-int4wo-128", safe_serialization=False)
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="push-to-huggingface-hub">
|
||||
```py
|
||||
# don't serialize model with Safetensors
|
||||
USER_ID = "your_huggingface_user_id"
|
||||
REPO_ID = "llama3-8b-int4wo-128"
|
||||
quantized_model.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128", safe_serialization=False)
|
||||
tokenizer.push_to_hub(f"{USER_ID}/llama3-8b-int4wo-128")
|
||||
```
|
||||
</hfoption>
|
||||
|
||||
|
||||
## Loading quantized models
|
||||
|
||||
@ -486,4 +468,4 @@ Refer to [Other Available Quantization Techniques](https://github.com/pytorch/ao
|
||||
|
||||
## Issues
|
||||
|
||||
If you encounter any issues with the Transformers integration, please open an issue on the [Transformers](https://github.com/huggingface/transformers/issues) repository. For issues directly related to torchao, please open an issue on the [torchao](https://github.com/pytorch/ao/issues) repository.
|
||||
If you encounter any issues with the Transformers integration, please open an issue on the [Transformers](https://github.com/huggingface/transformers/issues) repository. For issues directly related to torchao, please open an issue on the [torchao](https://github.com/pytorch/ao/issues) repository.
|
||||
|
Loading…
Reference in New Issue
Block a user