Add autoquant support for torchao quantizer (#35503)
* Add autoquant support for torchao quantizer

  Summary: att, also verified that autoquantized model can be saved and loaded:
  save: https://gist.github.com/jerryzh168/01d367aaf44dbbbfd4068a4a10a00061
  load: https://gist.github.com/jerryzh168/d5c6c401b2abdf18e0b6771341f1525c

  Test Plan: tested locally with above script
  model uploaded to https://huggingface.co/jerryzh168/llama3-8b-autoquant

  Reviewers:
  Subscribers:
  Tasks:
  Tags:

* add test
* ruff fix
* ruff reformat
* add docs and min_sqnr support
* format
* format
* fix test
* update doc
* format
* remove disable_compile
* format
This commit is contained in:
parent
977a61f743
commit
2af272c101
@ -22,6 +22,12 @@ pip install --upgrade torch torchao transformers
By default, the weights are loaded in full precision (torch.float32), regardless of the data type the weights were actually stored in, such as torch.float16. Set `torch_dtype="auto"` to load the weights in the data type defined in the model's `config.json` file and automatically pick the most memory-efficient data type.
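For example, a minimal sketch (the checkpoint name is only illustrative and not prescribed by this guide):

```py
from transformers import AutoModelForCausalLM

# The checkpoint name below is just an example; any causal LM checkpoint works the same way.
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",   # use the dtype recorded in config.json (e.g. torch.bfloat16)
    device_map="auto",
)
```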
## Manually Choose Quantization Types and Settings

`torchao` provides many commonly used types of quantization, covering different dtypes such as int4 and float8 and different flavors such as weight-only and dynamic quantization. Currently only `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight` are integrated into Hugging Face Transformers, but more can be added when needed.

Users can manually specify the quantization types and settings they want to use:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@ -41,19 +47,14 @@ output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implemen
print(tokenizer.decode(output[0], skip_special_tokens=True))

# benchmark the performance
import torch.utils.benchmark as benchmark
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(f, *args, **kwargs):
    # Manual warmup
    for _ in range(5):
        f(*args, **kwargs)

    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
print("int4wo-128 model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
@ -64,6 +65,47 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
```

## Automatically Select Quantization Types

`torchao` also provides an `autoquant` feature that automatically chooses a quantization type for quantizable layers (such as linear layers), based on microbenchmarks of quantizing and compiling a single linear layer.

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = "meta-llama/Meta-Llama-3-8B"
quantization_config = TorchAoConfig("autoquant", min_sqnr=None)
quantized_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto", device_map="auto", quantization_config=quantization_config)

tokenizer = AutoTokenizer.from_pretrained(model_name)
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# Due to some implementation details we have to call this explicitly for now; we may refactor our code and remove it in the future
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))

# benchmark the performance
from torch._inductor.utils import do_bench_using_profiling
from typing import Callable

def benchmark_fn(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

MAX_NEW_TOKENS = 1000
print("autoquantized model:", benchmark_fn(quantized_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))

bf16_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="cuda", torch_dtype=torch.bfloat16)
output = bf16_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")  # auto-compile
print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static"))
```

## Serialization and Deserialization

torchao quantization is implemented with [tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor), so it only works with Hugging Face non-safetensors serialization and deserialization. It relies on `torch.load(..., weights_only=True)` to avoid arbitrary user code execution at load time, and uses [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals) to allowlist some known user functions.
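Continuing from the example above, a minimal sketch of the save/load round trip under this constraint (the output directory name is illustrative):

```py
# Tensor subclasses cannot be represented in safetensors, so disable safe serialization when saving.
quantized_model.save_pretrained("llama3-8b-autoquant", safe_serialization=False)
tokenizer.save_pretrained("llama3-8b-autoquant")

# Reload the quantized checkpoint; the weights are deserialized via torch.load(..., weights_only=True).
reloaded_model = AutoModelForCausalLM.from_pretrained(
    "llama3-8b-autoquant",
    device_map="auto",
    torch_dtype="auto",
)
```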
@ -4914,7 +4914,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
    device_map is not None
    and hf_quantizer is not None
    and hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO
    and hf_quantizer.quantization_config.quant_type == "int4_weight_only"
    and hf_quantizer.quantization_config.quant_type in ["int4_weight_only", "autoquant"]
):
    map_location = torch.device([d for d in device_map.values() if d not in ["cpu", "disk"]][0])
state_dict = load_state_dict(
@ -129,6 +129,7 @@ class TorchAoHfQuantizer(HfQuantizer):
                "int4_weight_only": CustomDtype.INT4,
                "int8_weight_only": torch.int8,
                "int8_dynamic_activation_int8_weight": torch.int8,
                "autoquant": None,
            }
            return map_to_target_dtype[self.quantization_config.quant_type]
        else:
@ -161,6 +162,9 @@ class TorchAoHfQuantizer(HfQuantizer):
        state_dict: Dict[str, Any],
        **kwargs,
    ) -> bool:
        if self.quantization_config.quant_type == "autoquant":
            return False

        param_device = kwargs.pop("param_device", None)
        # check if the param_name is not in self.modules_to_not_convert
        if any((key + "." in param_name) or (key == param_name) for key in self.modules_to_not_convert):
@ -186,6 +190,9 @@ class TorchAoHfQuantizer(HfQuantizer):
        Each nn.Linear layer that needs to be quantized is processed here.
        First, we set the value of the weight tensor, then we move it to the target device. Finally, we quantize the module.
        """
        if self.quantization_config.quant_type == "autoquant":
            return

        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)
@ -200,6 +207,15 @@ class TorchAoHfQuantizer(HfQuantizer):
    def _process_model_after_weight_loading(self, model, **kwargs):
        """No process required for torchao quantized model"""
        if self.quantization_config.quant_type == "autoquant":
            from torchao import autoquant
            from torchao.quantization import ALL_AUTOQUANT_CLASS_LIST

            model = torch.compile(model, mode="max-autotune")
            model = autoquant(
                model, qtensor_class_list=ALL_AUTOQUANT_CLASS_LIST, **self.quantization_config.quant_type_kwargs
            )
            return model
        return

    def is_serializable(self, safe_serialization=None):
@ -1453,7 +1453,7 @@ class TorchAoConfig(QuantizationConfigMixin):
    Args:
        quant_type (`str`):
            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` and `autoquant`.
        modules_to_not_convert (`list`, *optional*, defaults to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
@ -1465,9 +1465,31 @@ class TorchAoConfig(QuantizationConfigMixin):
    Example:

    ```python
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

    # specific quantization method
    quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
    # int4_weight_only quant only works with the *torch.bfloat16* dtype right now
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

    # autoquant
    # `autoquant` is a convenient way for users to search for the best quantization for each layer
    # `min_sqnr` is an option to control the accuracy of the model: a higher value means the model is more
    # accurate; we can start with 30 and adjust it to be larger or smaller (e.g. 40, 20)
    # defaults to None, which means we'll try to get the best performing quantized model without
    # considering accuracy
    quantization_config = TorchAoConfig("autoquant", min_sqnr=30)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
    # run through example inputs; quantization methods will be selected based on the shape of the example input
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    input_text = "What are we having for dinner?"
    input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
    MAX_NEW_TOKENS = 1000
    model.generate(**input_ids, max_new_tokens=MAX_NEW_TOKENS, cache_implementation="static")
    # manually run finalize_autoquant if needed
    if hasattr(model, "finalize_autoquant"):
        print("finalizing autoquant")
        model.finalize_autoquant()
    ```
    """
@ -1488,8 +1510,8 @@ class TorchAoConfig(QuantizationConfigMixin):
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if is_torchao_available():
            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.4.0"):
                raise ValueError("Requires torchao 0.4.0 version and above")
            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.7.0"):
                raise ValueError("Requires torchao 0.7.0 version and above")
        else:
            raise ValueError(
                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
@ -1517,6 +1539,7 @@ class TorchAoConfig(QuantizationConfigMixin):
    def _get_torchao_quant_type_to_method(self):
        if is_torchao_available():
            from torchao.quantization import (
                autoquant,
                int4_weight_only,
                int8_dynamic_activation_int8_weight,
                int8_weight_only,
@ -1526,6 +1549,7 @@ class TorchAoConfig(QuantizationConfigMixin):
                "int4_weight_only": int4_weight_only,
                "int8_weight_only": int8_weight_only,
                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
                "autoquant": autoquant,
            }
        else:
            raise ValueError(
@ -31,10 +31,12 @@ if is_torch_available():
    import torch


if is_torchao_available():
    # renamed in torchao 0.7.0, please install the latest torchao
    from torchao.dtypes import (
        AffineQuantizedTensor,
        TensorCoreTiledLayout,
    )
    from torchao.quantization.autoquant import AQMixin


def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024):
@ -42,7 +44,12 @@ def check_torchao_quantized(test_module, qlayer, batch_size=1, context_size=1024
    test_module.assertTrue(isinstance(weight, AffineQuantizedTensor))
    test_module.assertEqual(weight.quant_min, 0)
    test_module.assertEqual(weight.quant_max, 15)
    test_module.assertTrue(isinstance(weight.layout, TensorCoreTiledLayout))
    test_module.assertTrue(isinstance(weight._layout, TensorCoreTiledLayout))


def check_autoquantized(test_module, qlayer):
    weight = qlayer.weight
    test_module.assertTrue(isinstance(weight, AQMixin))


def check_forward(test_module, model, batch_size=1, context_size=1024):
@ -248,6 +255,33 @@ class TorchAoTest(unittest.TestCase):
        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)

    def test_autoquant(self):
        """
        Simple LLM model testing autoquant
        """
        quant_config = TorchAoConfig("autoquant")

        quantized_model = AutoModelForCausalLM.from_pretrained(
            self.model_name,
            torch_dtype=torch.bfloat16,
            device_map=torch_device,
            quantization_config=quant_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(self.model_name)
        input_ids = tokenizer(self.input_text, return_tensors="pt").to(torch_device)
        output = quantized_model.generate(
            **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
        )
        quantized_model.finalize_autoquant()

        check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)

        EXPECTED_OUTPUT = 'What are we having for dinner?\n\n10. "Dinner is ready'
        output = quantized_model.generate(
            **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
        )
        self.assertEqual(tokenizer.decode(output[0], skip_special_tokens=True), EXPECTED_OUTPUT)


@require_torch_gpu
@require_torchao