[docs] Add int4wo + 2:4 sparsity example to TorchAO README (#38592)
Some checks are pending
Self-hosted runner (benchmark) / Benchmark (aws-g5-4xlarge-cache) (push) Waiting to run
Build documentation / build (push) Waiting to run
Slow tests on important models (on Push - A10) / Get all modified files (push) Waiting to run
Slow tests on important models (on Push - A10) / Slow & FA2 tests (push) Blocked by required conditions
Self-hosted runner (push-caller) / Check if setup was changed (push) Waiting to run
Self-hosted runner (push-caller) / build-docker-containers (push) Blocked by required conditions
Self-hosted runner (push-caller) / Trigger Push CI (push) Blocked by required conditions
Secret Leaks / trufflehog (push) Waiting to run
Update Transformers metadata / build_and_package (push) Waiting to run

* update quantization readme

* update

---------

Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com>
This commit is contained in:
Jesse Cai 2025-06-12 08:17:07 -04:00 committed by GitHub
parent bc68defcac
commit e1812864ab
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -38,6 +38,7 @@ torchao supports the [quantization techniques](https://github.com/pytorch/ao/blo
- A8W8 Int8 Dynamic Quantization
- A16W8 Int8 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization
- A16W4 Int4 Weight Only Quantization + 2:4 Sparsity
- Autoquantization
torchao also supports module level configuration by specifying a dictionary from fully qualified name of module and its corresponding quantization config. This allows skip quantizing certain layers and using different quantization config for different modules.
@ -147,6 +148,37 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
</hfoption>
</hfoptions>
</hfoption>
<hfoption id="int4-weight-only-24sparse">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout
quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuraccy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
"RedHatAI/Sparse-Llama-3.1-8B-2of4",
torch_dtype=torch.float16,
device_map="cuda",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### A100 GPU
<hfoptions id="examples-A100-GPU">
<hfoption id="int8-dynamic-and-weight-only">
@ -215,6 +247,37 @@ print(tokenizer.decode(output[0], skip_special_tokens=True))
</hfoption>
</hfoptions>
</hfoption>
<hfoption id="int4-weight-only-24sparse">
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
from torchao.dtypes import MarlinSparseLayout
quant_config = Int4WeightOnlyConfig(layout=MarlinSparseLayout())
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model with sparsity. A sparse checkpoint is needed to accelerate without accuraccy loss
quantized_model = AutoModelForCausalLM.from_pretrained(
"RedHatAI/Sparse-Llama-3.1-8B-2of4",
torch_dtype=torch.float16,
device_map="cuda",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("RedHatAI/Sparse-Llama-3.1-8B-2of4")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
</hfoption>
</hfoptions>
### CPU
<hfoptions id="examples-CPU">
<hfoption id="int8-dynamic-and-weight-only">