Add option for ao base configs (#36526)

Driss Guessous 2025-03-19 06:59:47 -07:00 committed by GitHub
parent fef8b7f8e9
commit e8d960329e
5 changed files with 293 additions and 87 deletions

View File

@@ -20,18 +20,95 @@ Install torchao with the following command.
pip install --upgrade torch torchao transformers
```
torchao supports many quantization types for different data types (int4, float8, weight only, etc.), but the Transformers integration only currently supports int8 weight quantization and int8 dynamic quantization of weights.
torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations and full access to the techniques offered in the torchao library.
You can either choose the quantization type and settings manually or let torchao select them automatically.
<hfoptions id="torchao">
<hfoption id="manual">
Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.
> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
# Using AOBaseConfig instance (torchao >= 0.10.0)
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
## Available Quantization Schemes
TorchAO provides a variety of quantization configurations:
- `Int4WeightOnlyConfig`
- `Int8WeightOnlyConfig`
- `Int8DynamicActivationInt8WeightConfig`
- `Float8WeightOnlyConfig`
Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures.
For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
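As a quick sketch (assuming torchao >= 0.10.0; the accepted parameters differ between config classes and releases), a configuration object is built with its settings and passed to [`TorchAoConfig`]:
```py
from torchao.quantization import Int4WeightOnlyConfig, Int8WeightOnlyConfig
from transformers import TorchAoConfig

# int4 weight-only quantization with a smaller group size for finer-grained scales
int4_config = Int4WeightOnlyConfig(group_size=32)

# int8 weight-only quantization with its default settings
int8_config = Int8WeightOnlyConfig()

quantization_config = TorchAoConfig(quant_type=int4_config)
```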
> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> ## Migration Guide
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).
Below is the API for torchao < `0.9.0`.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.
Create a [`TorchAoConfig`] and set it to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.
> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
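Below is a minimal sketch of that flow, mirroring the manual example above (the checkpoint and generation settings are illustrative):
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# `cache_implementation="static"` auto-compiles the forward pass while autoquant benchmarks the observed shapes
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
# finalize the quantization choices once the input shapes have been logged
quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```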
@@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
## Serialization
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchaco.
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.
To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
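As a short sketch reusing `quantized_model` from the examples above (the output directory name is illustrative):
```py
# safetensors serialization is unsupported for torchao tensor subclasses, so disable it explicitly
quantized_model.save_pretrained("llama3-8b-int4wo-128", safe_serialization=False)

# reloading deserializes with torch.load(weights_only=True) under the hood
reloaded_model = AutoModelForCausalLM.from_pretrained(
    "llama3-8b-int4wo-128",
    device_map="auto",
    torch_dtype="auto",
)
```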

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
import types
from typing import TYPE_CHECKING, Optional, Union
@@ -27,6 +28,7 @@ if TYPE_CHECKING:
from typing import Any, Dict, List
from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig
if is_torch_available():
@@ -36,6 +38,21 @@ if is_torch_available():
logger = logging.get_logger(__name__)
def fuzzy_match_size(config_name: str) -> Optional[str]:
"""
Extract the size digit from strings like "4weight", "8weight".
Returns the digit as a string if found, otherwise None.
"""
config_name = config_name.lower()
str_match = re.search(r"(\d)weight", config_name)
if str_match:
return str_match.group(1)
return None
# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
@@ -121,10 +138,28 @@ class TorchAoHfQuantizer(HfQuantizer):
torch_dtype = torch.float32
return torch_dtype
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
from accelerate.utils import CustomDtype
# Import AOBaseConfig directly since we know we have the right version
if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
# Original mapping for non-AOBaseConfig types
map_to_target_dtype = {
"int4_weight_only": CustomDtype.INT4,
"int8_weight_only": torch.int8,
@@ -194,14 +229,14 @@ class TorchAoHfQuantizer(HfQuantizer):
from torchao.quantization import quantize_
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
assert isinstance(self.quantization_config, TorchAoConfig)
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
quantize_(module, self.quantization_config.get_quantize_config())
def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
@@ -216,7 +251,7 @@ class TorchAoHfQuantizer(HfQuantizer):
return model
return
def is_serializable(self, safe_serialization=None):
def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, "
@@ -237,7 +272,7 @@ class TorchAoHfQuantizer(HfQuantizer):
return _is_torchao_serializable
@property
def is_trainable(self):
def is_trainable(self) -> bool:
supported_quant_types_for_training = [
"int8_weight_only",
"int8_dynamic_activation_int8_weight",

View File

@@ -95,6 +95,7 @@ GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -191,7 +192,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available = _is_package_available("torchao")
_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True)
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
_mlx_available = _is_package_available("mlx")
@@ -1277,8 +1278,8 @@ def is_torchaudio_available():
return _torchaudio_available
def is_torchao_available():
return _torchao_available
def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION):
return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version)
def is_speech_available():

View File

@@ -1455,11 +1455,18 @@ class HiggsConfig(QuantizationConfigMixin):
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
quant_method: QuantizationMethod
quant_type: Union[str, "AOBaseConfig"] # noqa: F821
modules_to_not_convert: Optional[List]
quant_type_kwargs: Dict[str, Any]
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`str`):
The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` and `autoquant`.
quant_type (`Union[str, AOBaseConfig]`):
The type of quantization we want to use. Can be either:
- A string: currently supporting: `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
- An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`list`, *optional*, default to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision.
@@ -1471,9 +1478,12 @@ class TorchAoConfig(QuantizationConfigMixin):
Example:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
# AOBaseConfig-based configuration (torchao >= 0.10.0)
from torchao.quantization import Int4WeightOnlyConfig
config = Int4WeightOnlyConfig(group_size=32)
quantization_config = TorchAoConfig(config)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
# specific quantization method
# String-based configuration
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
# int4_weight_only quant is only working with *torch.bfloat16* dtype right now
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
@@ -1496,105 +1506,152 @@ class TorchAoConfig(QuantizationConfigMixin):
if hasattr(quantized_model, "finalize_autoquant"):
print("finalizing autoquant")
quantized_model.finalize_autoquant()
```
"""
def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
def __init__(
self,
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
# when we load from a serialized config, "quant_type_kwargs" will be the key
if "quant_type_kwargs" in kwargs:
self.quant_type_kwargs = kwargs["quant_type_kwargs"]
else:
self.quant_type_kwargs = kwargs
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
self.post_init()
@staticmethod
def _get_ao_version() -> version.Version:
"""Centralized check for TorchAO availability and version requirements."""
if not is_torchao_available():
raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")
return version.parse(importlib.metadata.version("torchao"))
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if is_torchao_available():
if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.7.0"):
raise ValueError("Requires torchao 0.7.0 version and above")
"""Validate configuration and set defaults."""
ao_version = self._get_ao_version()
# Handle quant_type based on type and version
if isinstance(self.quant_type, str):
self._validate_string_quant_type()
elif ao_version >= version.parse("0.10.0"):
from torchao.quantization.quant_api import AOBaseConfig
if not isinstance(self.quant_type, AOBaseConfig):
raise ValueError(
f"quant_type must be either a string or an AOBaseConfig instance, got {type(self.quant_type)}"
)
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. "
f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances."
)
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
if self.quant_type not in _STR_TO_METHOD.keys():
def _validate_string_quant_type(self):
"""Validate string quant_type and its kwargs."""
methods = self._get_torchao_quant_type_to_method()
if self.quant_type not in methods:
raise ValueError(
f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
f"Unsupported string quantization type: {self.quant_type}. "
f"Supported types: {', '.join(methods.keys())}"
)
method = _STR_TO_METHOD[self.quant_type]
# Validate kwargs against method signature
method = methods[self.quant_type]
sig = signature(method)
all_kwargs = [
valid_kwargs = {
param.name
for param in sig.parameters.values()
if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
]
for k in self.quant_type_kwargs:
if k not in all_kwargs:
raise ValueError(
f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
)
}
invalid_kwargs = set(self.quant_type_kwargs) - valid_kwargs
if invalid_kwargs:
raise ValueError(
f"Unexpected keyword arg for {self.quant_type}: {', '.join(invalid_kwargs)}. "
f"Valid kwargs: {', '.join(valid_kwargs)}"
)
def _get_torchao_quant_type_to_method(self):
if is_torchao_available():
from torchao.quantization import (
autoquant,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
"""Get mapping of quant_type strings to their corresponding methods."""
from torchao.quantization import (
autoquant,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
return {
"int4_weight_only": int4_weight_only,
"int8_weight_only": int8_weight_only,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
"autoquant": autoquant,
}
return {
"int4_weight_only": int4_weight_only,
"int8_weight_only": int8_weight_only,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
"autoquant": autoquant,
}
def get_quantize_config(self):
"""Create the appropriate quantization method based on configuration."""
if isinstance(self.quant_type, str):
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
):
from torchao.dtypes import Int4CPULayout
quant_type_kwargs["layout"] = Int4CPULayout()
return methods[self.quant_type](**quant_type_kwargs)
else:
raise ValueError(
"TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
)
return self.quant_type
def get_apply_tensor_subclass(self):
_STR_TO_METHOD = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
):
from torchao.dtypes import Int4CPULayout
quant_type_kwargs["layout"] = Int4CPULayout()
return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs)
def __repr__(self):
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout`
dataclasses to simple dicts.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
layout = d["quant_type_kwargs"]["layout"]
layout = dataclasses.asdict(layout)
d["quant_type_kwargs"]["layout"] = layout
if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# For now we assume there is one config per Transformer model; in the future
# we may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
return d
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
ao_version = cls._get_ao_version()
assert ao_version >= version.parse(
"0.10.0"
), "TorchAoConfig requires torchao >= 0.10.0 for construction from dict"
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
# Check if we only have one key which is "default"
# In the future we may update this
assert (
len(quant_type) == 1 and "default" in quant_type
), "Expected only one key 'default' in quant_type dictionary"
quant_type = quant_type["default"]
# Deserialize quant_type if needed
from torchao.core.config import config_from_dict
quant_type = config_from_dict(quant_type)
return cls(quant_type=quant_type, **config_dict)
@dataclass
class BitNetConfig(QuantizationConfigMixin):

View File

@@ -85,7 +85,7 @@ class TorchAoConfigTest(unittest.TestCase):
Test kwargs validations in TorchAoConfig
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "is not supported yet"):
with self.assertRaisesRegex(ValueError, "Unsupported string quantization type"):
_ = TorchAoConfig("fp6")
with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
@@ -408,5 +408,41 @@ class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
device = "cuda:0"
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
def setUp(self):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Float8WeightOnlyConfig
self.quant_scheme = Float8WeightOnlyConfig()
self.quant_scheme_kwargs = {}
super().setUp()
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
def setUp(self):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for this test")
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
self.quant_scheme_kwargs = {}
super().setUp()
if __name__ == "__main__":
unittest.main()