Add option for ao base configs (#36526)

Driss Guessous 2025-03-19 06:59:47 -07:00 committed by GitHub
parent fef8b7f8e9
commit e8d960329e
5 changed files with 293 additions and 87 deletions


@@ -20,18 +20,95 @@ Install torchao with the following command.
pip install --upgrade torch torchao transformers
```
torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations and full access to the techniques offered in the torchao library.
You can manually choose the quantization types and settings or automatically select the quantization types.
<hfoptions id="torchao">
<hfoption id="manual">
Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.
> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig
# Using AOBaseConfig instance (torchao >= 0.10.0)
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)
# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Meta-Llama-3-8B",
torch_dtype="auto",
device_map="auto",
quantization_config=quantization_config
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")
# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
## Available Quantization Schemes
TorchAO provides a variety of quantization configurations:
- `Int4WeightOnlyConfig`
- `Int8WeightOnlyConfig`
- `Int8DynamicActivationInt8WeightConfig`
- `Float8WeightOnlyConfig`
Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures.
For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
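As a minimal sketch, wrapping one of these configs looks like the following (assuming torchao >= 0.10.0; the optional constructor arguments available for each config class vary across torchao releases, so check the quantization API docs for your version):
```py
from torchao.quantization import Int8WeightOnlyConfig
from transformers import TorchAoConfig

# Int8WeightOnlyConfig takes no required arguments; optional kwargs such as
# group_size depend on the installed torchao release
quant_config = Int8WeightOnlyConfig()
quantization_config = TorchAoConfig(quant_type=quant_config)
```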
> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> ## Migration Guide
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).
Below is the API for torchao < `0.9.0`.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.
Create a [`TorchAoConfig`] and set it to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.
> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
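The sketch below is a minimal example of that flow, based on the description above and the docstring example later in this commit; the checkpoint name and prompt are illustrative:
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# cache_implementation="static" triggers torch.compile of the forward method
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")

# finalize the per-layer quantization choices after the benchmarking run
if hasattr(quantized_model, "finalize_autoquant"):
    quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```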
@@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
## Serialization
torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.
To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
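As a minimal sketch, continuing from the quantization example above (the save path is illustrative), saving and reloading therefore goes through the non-safetensors serializer:
```py
from transformers import AutoModelForCausalLM

# torchao checkpoints rely on torch.save/torch.load, so disable safetensors
quantized_model.save_pretrained("llama-3-8b-int4wo-128", safe_serialization=False)

# reload the quantized checkpoint; torchao enforces weights_only=True during deserialization
reloaded_model = AutoModelForCausalLM.from_pretrained(
    "llama-3-8b-int4wo-128",
    device_map="auto",
    torch_dtype="auto",
)
```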


@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
import types
from typing import TYPE_CHECKING, Optional, Union
@@ -27,6 +28,7 @@ if TYPE_CHECKING:
from typing import Any, Dict, List
from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig
if is_torch_available():
@@ -36,6 +38,21 @@ if is_torch_available():
logger = logging.get_logger(__name__)
def fuzzy_match_size(config_name: str) -> Optional[str]:
"""
Extract the size digit from strings like "4weight", "8weight".
Returns the digit as a string if found, otherwise None.
"""
config_name = config_name.lower()
str_match = re.search(r"(\d)weight", config_name)
if str_match:
return str_match.group(1)
return None
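# Illustrative examples:
#   fuzzy_match_size("Int4WeightOnlyConfig")   -> "4"
#   fuzzy_match_size("Float8WeightOnlyConfig") -> "8"
#   fuzzy_match_size("autoquant")              -> None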
# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
@@ -121,10 +138,28 @@ class TorchAoHfQuantizer(HfQuantizer):
torch_dtype = torch.float32
return torch_dtype
def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
from accelerate.utils import CustomDtype
# Import AOBaseConfig directly since we know we have the right version
if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
from torchao.core.config import AOBaseConfig
quant_type = self.quantization_config.quant_type
if isinstance(quant_type, AOBaseConfig):
# Extract size digit using fuzzy match on the class name
config_name = quant_type.__class__.__name__
size_digit = fuzzy_match_size(config_name)
# Map the extracted digit to appropriate dtype
if size_digit == "4":
return CustomDtype.INT4
else:
# Default to int8
return torch.int8
# Original mapping for non-AOBaseConfig types
map_to_target_dtype = {
"int4_weight_only": CustomDtype.INT4,
"int8_weight_only": torch.int8,
@@ -194,14 +229,14 @@ class TorchAoHfQuantizer(HfQuantizer):
from torchao.quantization import quantize_
module, tensor_name = get_module_from_name(model, param_name)
if self.pre_quantized:
module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
if isinstance(module, nn.Linear):
module.extra_repr = types.MethodType(_linear_extra_repr, module)
else:
assert isinstance(self.quantization_config, TorchAoConfig)
module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
quantize_(module, self.quantization_config.get_quantize_config())
def _process_model_after_weight_loading(self, model, **kwargs):
"""No process required for torchao quantized model"""
@@ -216,7 +251,7 @@ class TorchAoHfQuantizer(HfQuantizer):
return model
return
def is_serializable(self, safe_serialization=None) -> bool:
if safe_serialization:
logger.warning(
"torchao quantized model does not support safe serialization, "
@@ -237,7 +272,7 @@ class TorchAoHfQuantizer(HfQuantizer):
return _is_torchao_serializable
@property
def is_trainable(self) -> bool:
supported_quant_types_for_training = [
"int8_weight_only",
"int8_dynamic_activation_int8_weight",


@@ -95,6 +95,7 @@ GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -191,7 +192,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True)
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
_mlx_available = _is_package_available("mlx")
@@ -1277,8 +1278,8 @@ def is_torchaudio_available():
return _torchaudio_available
def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION):
return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version)
def is_speech_available():


@@ -1455,11 +1455,18 @@ class HiggsConfig(QuantizationConfigMixin):
@dataclass
class TorchAoConfig(QuantizationConfigMixin):
quant_method: QuantizationMethod
quant_type: Union[str, "AOBaseConfig"] # noqa: F821
modules_to_not_convert: Optional[List]
quant_type_kwargs: Dict[str, Any]
"""This is a config class for torchao quantization/sparsity techniques.
Args:
quant_type (`Union[str, AOBaseConfig]`):
The type of quantization we want to use. Can be either:
- A string: currently supporting `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
- An AOBaseConfig instance: for more advanced configuration options.
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision.
@@ -1471,9 +1478,12 @@ class TorchAoConfig(QuantizationConfigMixin):
Example:
```python
# AOBaseConfig-based configuration
config = Int4WeightOnlyConfig(group_size=32)
quantization_config = TorchAoConfig(config)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
# String-based configuration
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
# int4_weight_only quant is only working with *torch.bfloat16* dtype right now
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
@@ -1496,105 +1506,152 @@ class TorchAoConfig(QuantizationConfigMixin):
if hasattr(quantized_model, "finalize_autoquant"):
print("finalizing autoquant")
quantized_model.finalize_autoquant()
```
"""
def __init__(
self,
quant_type: Union[str, "AOBaseConfig"], # noqa: F821
modules_to_not_convert: Optional[List] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.TORCHAO
self.quant_type = quant_type
self.modules_to_not_convert = modules_to_not_convert
self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
self.post_init()
@staticmethod
def _get_ao_version() -> version.Version:
"""Centralized check for TorchAO availability and version requirements."""
if not is_torchao_available():
raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")
return version.parse(importlib.metadata.version("torchao"))
def post_init(self):
"""Validate configuration and set defaults."""
ao_version = self._get_ao_version()
# Handle quant_type based on type and version
if isinstance(self.quant_type, str):
self._validate_string_quant_type()
elif ao_version >= version.parse("0.10.0"):
from torchao.quantization.quant_api import AOBaseConfig
if not isinstance(self.quant_type, AOBaseConfig):
raise ValueError(
f"quant_type must be either a string or an AOBaseConfig instance, got {type(self.quant_type)}"
)
else:
raise ValueError(
f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. "
f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances."
)
def _validate_string_quant_type(self):
"""Validate string quant_type and its kwargs."""
methods = self._get_torchao_quant_type_to_method()
if self.quant_type not in methods:
raise ValueError(
f"Unsupported string quantization type: {self.quant_type}. "
f"Supported types: {', '.join(methods.keys())}"
)
# Validate kwargs against method signature
method = methods[self.quant_type]
sig = signature(method)
valid_kwargs = {
param.name
for param in sig.parameters.values()
if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
}
invalid_kwargs = set(self.quant_type_kwargs) - valid_kwargs
if invalid_kwargs:
raise ValueError(
f"Unexpected keyword arg for {self.quant_type}: {', '.join(invalid_kwargs)}. "
f"Valid kwargs: {', '.join(valid_kwargs)}"
)
def _get_torchao_quant_type_to_method(self):
"""Get mapping of quant_type strings to their corresponding methods."""
from torchao.quantization import (
autoquant,
int4_weight_only,
int8_dynamic_activation_int8_weight,
int8_weight_only,
)
return {
"int4_weight_only": int4_weight_only,
"int8_weight_only": int8_weight_only,
"int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
"autoquant": autoquant,
}
def get_quantize_config(self):
"""Create the appropriate quantization method based on configuration."""
if isinstance(self.quant_type, str):
methods = self._get_torchao_quant_type_to_method()
quant_type_kwargs = self.quant_type_kwargs.copy()
if (
not torch.cuda.is_available()
and is_torchao_available()
and self.quant_type == "int4_weight_only"
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
):
from torchao.dtypes import Int4CPULayout
quant_type_kwargs["layout"] = Int4CPULayout()
return methods[self.quant_type](**quant_type_kwargs)
else:
return self.quant_type
def to_dict(self):
"""Convert configuration to a dictionary."""
d = super().to_dict()
if isinstance(self.quant_type, str):
# Handle layout serialization if present
if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
else:
# Handle AOBaseConfig serialization
from torchao.core.config import config_to_dict
# For now we assume there is 1 config per Transformer, however in the future
# We may want to support a config per fqn.
d["quant_type"] = {"default": config_to_dict(self.quant_type)}
return d
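# Illustrative (assumed) serialized shapes, for orientation only:
#   string config:       {"quant_method": "torchao", "quant_type": "int4_weight_only", "quant_type_kwargs": {"group_size": 128}, ...}
#   AOBaseConfig config: {"quant_method": "torchao", "quant_type": {"default": <config_to_dict(...) payload>}, ...}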
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
"""Create configuration from a dictionary."""
ao_version = cls._get_ao_version()
assert ao_version >= version.parse(
"0.10.0"
), "TorchAoConfig requires torchao >= 0.10.0 for construction from dict"
config_dict = config_dict.copy()
quant_type = config_dict.pop("quant_type")
# Check if we only have one key which is "default"
# In the future we may update this
assert (
len(quant_type) == 1 and "default" in quant_type
), "Expected only one key 'default' in quant_type dictionary"
quant_type = quant_type["default"]
# Deserialize quant_type if needed
from torchao.core.config import config_from_dict
quant_type = config_from_dict(quant_type)
return cls(quant_type=quant_type, **config_dict)
@dataclass
class BitNetConfig(QuantizationConfigMixin):


@@ -85,7 +85,7 @@ class TorchAoConfigTest(unittest.TestCase):
Test kwargs validations in TorchAoConfig
"""
_ = TorchAoConfig("int4_weight_only")
with self.assertRaisesRegex(ValueError, "Unsupported string quantization type"):
_ = TorchAoConfig("fp6")
with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
@@ -408,5 +408,41 @@ class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
device = "cuda:0"
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
def setUp(self):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
from torchao.quantization import Float8WeightOnlyConfig
self.quant_scheme = Float8WeightOnlyConfig()
self.quant_scheme_kwargs = {}
super().setUp()
@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
device = "cuda:0"
def setUp(self):
if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for this test")
from torchao.quantization import Int8DynamicActivationInt4WeightConfig
self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
self.quant_scheme_kwargs = {}
super().setUp()
if __name__ == "__main__":
unittest.main()