Add autoquant support for torchao quantizer (#35503)

* Add autoquant support for torchao quantizer

Summary:
As titled. Also verified that the autoquantized model can be saved and loaded:

save: https://gist.github.com/jerryzh168/01d367aaf44dbbbfd4068a4a10a00061
load: https://gist.github.com/jerryzh168/d5c6c401b2abdf18e0b6771341f1525c

Test Plan:
Tested locally with the scripts above.
Model uploaded to https://huggingface.co/jerryzh168/llama3-8b-autoquant
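
For reference, a minimal sketch of pulling the uploaded checkpoint back down (this is not taken from the gists above; exact usage may differ):

```py
from transformers import AutoModelForCausalLM, AutoTokenizer

# the quantization config is stored in the checkpoint's config.json,
# so the autoquantized weights should be restored automatically
model = AutoModelForCausalLM.from_pretrained(
    "jerryzh168/llama3-8b-autoquant", device_map="auto", torch_dtype="auto"
)
# assumes the repo does not ship its own tokenizer
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
```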


* add test

* ruff fix

* ruff reformat

* add docs and min_sqnr support

* format

* format

* fix test

* update doc

* format

* remove disable_compile

* format
Jerry Zhang, 2025-02-24 06:54:16 -08:00, committed by GitHub
parent 977a61f743
commit 2af272c101
5 changed files with 133 additions and 17 deletions


@ -22,6 +22,12 @@ pip install --upgrade torch torchao transformers
By default, the weights are loaded in full precision (torch.float32), regardless of the data type the weights are actually stored in, such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in a model's `config.json` file, which automatically picks the most memory-optimal data type.
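
For illustration, a minimal sketch of the difference (the model name is only an example):

```py
from transformers import AutoModelForCausalLM

# loads the weights in torch.float32 regardless of the checkpoint's stored dtype
model_fp32 = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

# loads the weights in the dtype recorded in the checkpoint's config.json (e.g. torch.bfloat16)
model_auto = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B", torch_dtype="auto")
```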
## Manually Choose Quantization Types and Settings
`torchao` provides many commonly used quantization types, including different dtypes like int4 and float8 and different flavors like weight-only and dynamic quantization. Currently only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into Hugging Face Transformers, but more can be added when needed.
Users can manually specify the quantization types and settings they want to use:
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@ -41,19 +47,14 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
print(tokenizer.decode(output[0], skip_special_tokens=True))
# benchmark the performance
-import torch.utils.benchmark as benchmark
+from torch._inductor.utils import do_bench_using_profiling
+from typing import Callable

-def benchmark_fn(f, *args, **kwargs):
-    # Manual warmup
-    for _ in range(5):
-        f(*args, **kwargs)
-    t0 = benchmark.Timer(
-        stmt="f(*args, **kwargs)",
-        globals={"args": args, "kwargs": kwargs, "f": f},
-        num_threads=torch.get_num_threads(),
-    )
-    return f"{(t0.blocked_autorange().mean):.3f}"
+def benchmark_fn(func: Callable, *args, **kwargs) -> float:
+    """Thin wrapper around do_bench_using_profiling"""
+    no_args = lambda: func(*args, **kwargs)
+    time = do_bench_using_profiling(no_args)
+    return time * 1e3
MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
@ -64,6 +65,47 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
```
## Automatically Select Quantization Types
`torchao` also provides an `autoquant` feature that automatically chooses a quantization type for quantizable layers (such as linear layers), based on microbenchmarks that quantize and compile a single linear layer at a time.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# due to some implementation details we need to call this explicitly for now; a future refactor may remove it
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
# benchmark the performance
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable
def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3
MAX_NEW_TOKENS = 1000
print("autoquantized model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static") # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```
## Serialization and Deserialization
torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), so it only works with Hugging Face non-safetensors serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution at load time, and uses [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist known safe functions.
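
Continuing the examples above, a minimal save/load sketch (the local path is illustrative):

```py
# torchao tensor subclasses are not supported by safetensors,
# so serialize with safe_serialization=False
quantized_model.save_pretrained("llama3-8b-torchao", safe_serialization=False)
tokenizer.save_pretrained("llama3-8b-torchao")

# on load, transformers restores the quantized weights through
# torch.load(..., weights_only=True)
reloaded_model = AutoModelForCausalLM.from_pretrained("llama3-8b-torchao", device_map="cuda")
```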


@ -4914,7 +4914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
            device_map is not None
            and hf_quantizer is not None
            and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
-            and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
+            and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
        ):
            map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
            state_dict = load_state_dict(


@ -129,6 +129,7 @@ class TorchAoHfQuantizer(HfQuantizer):
"int4_weight_only": CustomDtype.INT4,
"int8_weight_only": torch.int8,
"int8_dynamic_activation_int8_weight": torch.int8,
"autoquant": None,
}
return map_to_target_dtype[self.quantization_config.quant_type]
else:
@ -161,6 +162,9 @@ class TorchAoHfQuantizer(HfQuantizer):
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        if self.quantization_config.quant_type == "autoquant":
            return False

        param_device = kwargs.pop("param_device", None)
        # check if the param_name is not in self.modules_to_not_convert
        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
@ -186,6 +190,9 @@ class TorchAoHfQuantizer(HfQuantizer):
        Each nn.Linear layer that needs to be quantized is processed here.
        First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
        """
        if self.quantization_config.quant_type == "autoquant":
            return

        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)
@ -200,6 +207,15 @@ class TorchAoHfQuantizer(HfQuantizer):
    def _process_model_after_weight_loading(self, model, **kwargs):
        """No process required for torchao quantized model"""
        if self.quantization_config.quant_type == "autoquant":
            from torchao import autoquant
            from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST

            model = torch.compile(model, mode="max-autotune")
            model = autoquant(
                model, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST, **self.quantization_config.quant_type_kwargs
            )
            return model
        return

    def is_serializable(self, safe_serialization=None):
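
For context, a sketch of the equivalent standalone torchao flow that this hook wires into transformers (`model` is an already-loaded module and `example_input` is a placeholder; assumes torchao >= 0.7):

```py
import torch
from torchao import autoquant
from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST

# compile first so the microbenchmarks measure compiled kernels
model = torch.compile(model, mode="max-autotune")
# wrap quantizable layers so candidate quantization types can be benchmarked
model = autoquant(model, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST)

# run representative inputs; the shapes seen here drive the selection
model(example_input)  # example_input is a placeholder batch

# pick and bake in the winning quantization type per layer
model.finalize_autoquant()
```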


@ -1453,7 +1453,7 @@ class TorchAoConfig(QuantizationConfigMixin):
    Args:
        quant_type (`str`):
-            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
+            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` and `autoquant`.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
@ -1465,9 +1465,31 @@ class TorchAoConfig(QuantizationConfigMixin):
Example:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
# specific quantization method
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
# int4_weight_only quantization only works with *torch.bfloat16* dtype right now
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
# autoquant
# `autoquant` is a convenient way to search for the best quantization type for each layer
# `min_sqnr` controls the accuracy of the model: a higher value means the model is more
# accurate; we can start with 30 and adjust it up or down (e.g. 40, 20)
# it defaults to None, which means we try to get the best performing quantized model
# without considering accuracy
quantization_config = TorchAoConfig("autoquant", min_sqnr=30)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
# run example inputs through the model; quantization types are selected based on the shapes of the example inputs
tokenizer = AutoTokenizer.from_pretrained(model_id)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
MAX_NEW_TOKENS = 1000
model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")
# manually run finalize_autoquant if needed
if hasattr(model, "finalize_autoquant"):
    print("finalizing autoquant")
    model.finalize_autoquant()
```
"""
@ -1488,8 +1510,8 @@ class TorchAoConfig(QuantizationConfigMixin):
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if is_torchao_available():
-            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
-                raise ValueError("Requires torchao 0.4.0 version and above")
+            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.7.0"):
+                raise ValueError("Requires torchao 0.7.0 version and above")
        else:
            raise ValueError(
                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
@ -1517,6 +1539,7 @@ class TorchAoConfig(QuantizationConfigMixin):
    def _get_torchao_quant_type_to_method(self):
        if is_torchao_available():
            from torchao.quantization import (
                autoquant,
                int4_weight_only,
                int8_dynamic_activation_int8_weight,
                int8_weight_only,
@ -1526,6 +1549,7 @@ class TorchAoConfig(QuantizationConfigMixin):
"int4_weight_only": int4_weight_only,
"int8_weight_only": int8_weight_only,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
"autoquant": autoquant,
}
else:
raise ValueError(


@ -31,10 +31,12 @@ if is_torch_available():
    import torch

if is_torchao_available():
    # renamed in torchao 0.7.0, please install the latest torchao
    from torchao.dtypes import (
        AffineQuantizedTensor,
        TensorCoreTiledLayout,
    )
    from torchao.quantization.autoquant import AQMixin


def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
@ -42,7 +44,12 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
    test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
    test_module.assertEqual(weight.quant_min, 0)
    test_module.assertEqual(weight.quant_max, 15)
-    test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
+    test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout))


def check_autoquantized(test_module, qlayer):
    weight = qlayer.weight
    test_module.assertTrue(isinstance(weight, AQMixin))

def check_forward(test_module, model, batch_size=1, context_size=1024):
@ -248,6 +255,33 @@ class TorchAoTest(unittest.TestCase):
EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
def test_autoquant(self):
"""
Simple LLM model testing autoquant
"""
quant_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name,
torch_dtype=torch.bfloat16,
device_map=torch_device,
quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained(self.model_name)
input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
output = quantized_model.generate(
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
)
quantized_model.finalize_autoquant()
check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
output = quantized_model.generate(
**input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
)
self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)
@require_torch_gpu
@require_torchao