Add option for ao base configs (#36526)

commit e8d960329e (parent fef8b7f8e9)
@@ -20,18 +20,95 @@ Install torchao with the following command.

pip install --upgrade torch torchao transformers
```

-torchao supports many quantization types for different data types (int4, float8, weight only, etc.), but the Transformers integration only currently supports int8 weight quantization and int8 dynamic quantization of weights.
+torchao supports many quantization types for different data types (int4, float8, weight only, etc.).
+Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations and full access to the techniques offered in the torchao library.

You can manually choose the quantization types and settings or automatically select the quantization types.

<hfoptions id="torchao">
<hfoption id="manual">

Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.

In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig

# Using an AOBaseConfig instance (torchao >= 0.10.0)
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get a speedup
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
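
Following the TIP above, here is a minimal CPU variant of the same flow (a hedged sketch: `Int4CPULayout` requires torchao 0.8.0+, and `layout` is assumed here to be an accepted parameter of `Int4WeightOnlyConfig`):

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM
from torchao.dtypes import Int4CPULayout
from torchao.quantization import Int4WeightOnlyConfig

# int4 weight-only quantization laid out for CPU execution
quant_config = Int4WeightOnlyConfig(group_size=128, layout=Int4CPULayout())
quantization_config = TorchAoConfig(quant_type=quant_config)

quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="cpu",  # run on CPU instead of GPU
    quantization_config=quantization_config,
)
```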

## Available Quantization Schemes

TorchAO provides a variety of quantization configurations:

- `Int4WeightOnlyConfig`
- `Int8WeightOnlyConfig`
- `Int8DynamicActivationInt8WeightConfig`
- `Float8WeightOnlyConfig`

Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures.

For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
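
For illustration, a brief sketch of customizing these configs before handing them to [`TorchAoConfig`] (the parameter values are arbitrary examples, not recommendations):

```py
from transformers import TorchAoConfig
from torchao.quantization import (
    Int4WeightOnlyConfig,
    Int8DynamicActivationInt8WeightConfig,
)

# int4 weight-only with a smaller group size for finer-grained scales
int4_config = Int4WeightOnlyConfig(group_size=64)

# dynamic int8 activations with int8 weights, default settings
a8w8_config = Int8DynamicActivationInt8WeightConfig()

# any of the configs plugs into TorchAoConfig the same way
quantization_config = TorchAoConfig(quant_type=int4_config)
```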

> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> ## Migration Guide
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).

Below is the API for torchao < `0.9.0`:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens

The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.

Create a [`TorchAoConfig`] and set to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
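
For reference, a condensed sketch of the autoquant flow described above (using the same illustrative checkpoint as the manual example; the full benchmarking script appears in the surrounding diff context):

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

# "autoquant" lets torchao micro-benchmark and pick a scheme per nn.Linear
quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# the static cache triggers torch.compile on the forward method
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")

# finalize the quantization once representative shapes have been seen
if hasattr(quantized_model, "finalize_autoquant"):
    quantized_model.finalize_autoquant()

print(tokenizer.decode(output[0], skip_special_tokens=True))
```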

@@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_tokens

## Serialization

-torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchaco.
+torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.

To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
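
In practice this looks roughly as follows (a hedged sketch; the output directory name is hypothetical, and the key points are `safe_serialization=False` on save and the `weights_only=True` load path):

```py
from transformers import AutoModelForCausalLM, TorchAoConfig

quantization_config = TorchAoConfig("int8_weight_only")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)

# Safetensors cannot represent torchao's tensor subclasses, so opt out of it
model.save_pretrained("llama-3-8b-int8wo", safe_serialization=False)  # hypothetical output dir

# Reloading goes through torch.load(..., weights_only=True) under the hood
reloaded = AutoModelForCausalLM.from_pretrained(
    "llama-3-8b-int8wo",
    torch_dtype="auto",
    device_map="auto",
)
```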

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
+import re
import types
from typing import TYPE_CHECKING, Optional, Union
@@ -27,6 +28,7 @@ if TYPE_CHECKING:
    from typing import Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig


if is_torch_available():
@@ -36,6 +38,21 @@ if is_torch_available():
logger = logging.get_logger(__name__)


+def fuzzy_match_size(config_name: str) -> Optional[str]:
+    """
+    Extract the size digit from strings like "4weight", "8weight".
+    Returns the digit as a string if found, otherwise None.
+    """
+    config_name = config_name.lower()
+
+    str_match = re.search(r"(\d)weight", config_name)
+
+    if str_match:
+        return str_match.group(1)
+
+    return None
+
+
# Finds the parent of a node module named "name"
def find_parent(model, name):
    module_tree = name.split(".")[:-1]
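
A quick hedged illustration of the fuzzy matcher above, applied to the torchao config class names it is meant to parse:

```py
# "Int4WeightOnlyConfig".lower() contains "4weight", so the regex captures "4"
assert fuzzy_match_size("Int4WeightOnlyConfig") == "4"
assert fuzzy_match_size("Float8WeightOnlyConfig") == "8"
assert fuzzy_match_size("SomeOtherConfig") is None  # no "<digit>weight" substring
```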

@@ -121,10 +138,28 @@ class TorchAoHfQuantizer(HfQuantizer):
            torch_dtype = torch.float32
        return torch_dtype

-    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+    def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
            from accelerate.utils import CustomDtype

+            # Import AOBaseConfig directly since we know we have the right version
+            if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
+                from torchao.core.config import AOBaseConfig
+
+                quant_type = self.quantization_config.quant_type
+                if isinstance(quant_type, AOBaseConfig):
+                    # Extract the size digit using a fuzzy match on the class name
+                    config_name = quant_type.__class__.__name__
+                    size_digit = fuzzy_match_size(config_name)
+
+                    # Map the extracted digit to the appropriate dtype
+                    if size_digit == "4":
+                        return CustomDtype.INT4
+                    else:
+                        # Default to int8
+                        return torch.int8
+
+            # Original mapping for non-AOBaseConfig types
            map_to_target_dtype = {
                "int4_weight_only": CustomDtype.INT4,
                "int8_weight_only": torch.int8,
@@ -194,14 +229,14 @@ class TorchAoHfQuantizer(HfQuantizer):
        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)

        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if isinstance(module, nn.Linear):
                module.extra_repr = types.MethodType(_linear_extra_repr, module)
        else:
            assert isinstance(self.quantization_config, TorchAoConfig)
            module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
-            quantize_(module, self.quantization_config.get_apply_tensor_subclass())
+            quantize_(module, self.quantization_config.get_quantize_config())

    def _process_model_after_weight_loading(self, model, **kwargs):
        """No process required for torchao quantized model"""
@@ -216,7 +251,7 @@ class TorchAoHfQuantizer(HfQuantizer):
            return model
        return

-    def is_serializable(self, safe_serialization=None):
+    def is_serializable(self, safe_serialization=None) -> bool:
        if safe_serialization:
            logger.warning(
                "torchao quantized model does not support safe serialization, "
@@ -237,7 +272,7 @@ class TorchAoHfQuantizer(HfQuantizer):
        return _is_torchao_serializable

    @property
-    def is_trainable(self):
+    def is_trainable(self) -> bool:
        supported_quant_types_for_training = [
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",
@@ -95,6 +95,7 @@ GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
+TORCHAO_MIN_VERSION = "0.4.0"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -191,7 +192,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
-_torchao_available = _is_package_available("torchao")
+_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True)
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
_mlx_available = _is_package_available("mlx")
@@ -1277,8 +1278,8 @@ def is_torchaudio_available():
    return _torchaudio_available


-def is_torchao_available():
-    return _torchao_available
+def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION):
+    return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version)


def is_speech_available():
@@ -1455,11 +1455,18 @@ class HiggsConfig(QuantizationConfigMixin):

@dataclass
class TorchAoConfig(QuantizationConfigMixin):
+    quant_method: QuantizationMethod
+    quant_type: Union[str, "AOBaseConfig"]  # noqa: F821
+    modules_to_not_convert: Optional[List]
+    quant_type_kwargs: Dict[str, Any]
+
    """This is a config class for torchao quantization/sparsity techniques.

    Args:
-        quant_type (`str`):
-            The type of quantization we want to use, currently supporting: `int4_weight_only`, `int8_weight_only`, `int8_dynamic_activation_int8_weight` and `autoquant`.
+        quant_type (`Union[str, AOBaseConfig]`):
+            The type of quantization we want to use. Can be either:
+            - A string: currently supporting `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
+            - An `AOBaseConfig` instance: for more advanced configuration options.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
@@ -1471,9 +1478,12 @@ class TorchAoConfig(QuantizationConfigMixin):
    Example:

    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

+    # AOBaseConfig-based configuration
+    config = Int4WeightOnlyConfig(group_size=32)
+    quantization_config = TorchAoConfig(config)
+    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

-    # specific quantization method
+    # String-based configuration
    quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
    # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
@@ -1496,105 +1506,152 @@ class TorchAoConfig(QuantizationConfigMixin):
    if hasattr(quantized_model, "finalize_autoquant"):
        print("finalizing autoquant")
        quantized_model.finalize_autoquant()

    ```
    """

-    def __init__(self, quant_type: str, modules_to_not_convert: Optional[List] = None, **kwargs):
+    def __init__(
+        self,
+        quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
+        modules_to_not_convert: Optional[List] = None,
+        **kwargs,
+    ):
        self.quant_method = QuantizationMethod.TORCHAO
        self.quant_type = quant_type
        self.modules_to_not_convert = modules_to_not_convert
-        # when we load from serailized config, "quant_type_kwargs" will be the key
-        if "quant_type_kwargs" in kwargs:
-            self.quant_type_kwargs = kwargs["quant_type_kwargs"]
-        else:
-            self.quant_type_kwargs = kwargs
-
+        self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)
        self.post_init()

+    @staticmethod
+    def _get_ao_version() -> version.Version:
+        """Centralized check for TorchAO availability and version requirements."""
+        if not is_torchao_available():
+            raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")
+
+        return version.parse(importlib.metadata.version("torchao"))
+
    def post_init(self):
-        r"""
-        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
-        """
-        if is_torchao_available():
-            if not version.parse(importlib.metadata.version("torchao")) >= version.parse("0.7.0"):
-                raise ValueError("Requires torchao 0.7.0 version and above")
+        """Validate configuration and set defaults."""
+        ao_version = self._get_ao_version()
+
+        # Handle quant_type based on type and version
+        if isinstance(self.quant_type, str):
+            self._validate_string_quant_type()
+        elif ao_version >= version.parse("0.10.0"):
+            from torchao.quantization.quant_api import AOBaseConfig
+
+            if not isinstance(self.quant_type, AOBaseConfig):
+                raise ValueError(
+                    f"quant_type must be either a string or an AOBaseConfig instance, got {type(self.quant_type)}"
+                )
        else:
            raise ValueError(
-                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
+                f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. "
+                f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances."
            )

-        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        if self.quant_type not in _STR_TO_METHOD.keys():
+    def _validate_string_quant_type(self):
+        """Validate string quant_type and its kwargs."""
+        methods = self._get_torchao_quant_type_to_method()
+
+        if self.quant_type not in methods:
            raise ValueError(
-                f"Requested quantization type: {self.quant_type} is not supported yet, please add support in TorchAoConfig and TorchAoHfQuantizer."
+                f"Unsupported string quantization type: {self.quant_type}. "
+                f"Supported types: {', '.join(methods.keys())}"
            )

-        method = _STR_TO_METHOD[self.quant_type]
+        # Validate kwargs against method signature
+        method = methods[self.quant_type]
        sig = signature(method)
-        all_kwargs = [
+        valid_kwargs = {
            param.name
            for param in sig.parameters.values()
            if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
-        ]
-        for k in self.quant_type_kwargs:
-            if k not in all_kwargs:
-                raise ValueError(
-                    f"Unexpected keyword arg: {k} for API: {method}, accepted keyword args are: {all_kwargs}"
-                )
+        }
+
+        invalid_kwargs = set(self.quant_type_kwargs) - valid_kwargs
+        if invalid_kwargs:
+            raise ValueError(
+                f"Unexpected keyword arg for {self.quant_type}: {', '.join(invalid_kwargs)}. "
+                f"Valid kwargs: {', '.join(valid_kwargs)}"
+            )
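
To make the validation concrete, a hedged sketch of what `_validate_string_quant_type` accepts and rejects (mirroring the updated tests later in this diff):

```py
from transformers import TorchAoConfig

# valid: group_size is a keyword of torchao's int4_weight_only API
TorchAoConfig("int4_weight_only", group_size=128)

# raises ValueError("Unsupported string quantization type: fp6. ...")
TorchAoConfig("fp6")

# raises ValueError("Unexpected keyword arg for int4_weight_only: group_szie. ...")
TorchAoConfig("int4_weight_only", group_szie=128)
```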

    def _get_torchao_quant_type_to_method(self):
-        if is_torchao_available():
-            from torchao.quantization import (
-                autoquant,
-                int4_weight_only,
-                int8_dynamic_activation_int8_weight,
-                int8_weight_only,
-            )
+        """Get mapping of quant_type strings to their corresponding methods."""
+        from torchao.quantization import (
+            autoquant,
+            int4_weight_only,
+            int8_dynamic_activation_int8_weight,
+            int8_weight_only,
+        )

-            return {
-                "int4_weight_only": int4_weight_only,
-                "int8_weight_only": int8_weight_only,
-                "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
-                "autoquant": autoquant,
-            }
+        return {
+            "int4_weight_only": int4_weight_only,
+            "int8_weight_only": int8_weight_only,
+            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
+            "autoquant": autoquant,
+        }

+    def get_quantize_config(self):
+        """Create the appropriate quantization method based on configuration."""
+        if isinstance(self.quant_type, str):
+            methods = self._get_torchao_quant_type_to_method()
+            quant_type_kwargs = self.quant_type_kwargs.copy()
+            if (
+                not torch.cuda.is_available()
+                and is_torchao_available()
+                and self.quant_type == "int4_weight_only"
+                and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
+            ):
+                from torchao.dtypes import Int4CPULayout
+
+                quant_type_kwargs["layout"] = Int4CPULayout()
+
+            return methods[self.quant_type](**quant_type_kwargs)
        else:
-            raise ValueError(
-                "TorchAoConfig requires torchao to be installed, please install with `pip install torchao`"
-            )
+            return self.quant_type

-    def get_apply_tensor_subclass(self):
-        _STR_TO_METHOD = self._get_torchao_quant_type_to_method()
-        quant_type_kwargs = self.quant_type_kwargs.copy()
-        if (
-            not torch.cuda.is_available()
-            and is_torchao_available()
-            and self.quant_type == "int4_weight_only"
-            and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
-        ):
-            from torchao.dtypes import Int4CPULayout
-
-            quant_type_kwargs["layout"] = Int4CPULayout()
-        return _STR_TO_METHOD[self.quant_type](**quant_type_kwargs)

    def __repr__(self):
        config_dict = self.to_dict()
        return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"

-    def to_dict(self) -> Dict[str, Any]:
-        """
-        Serializes this instance to a Python dictionary, converting any `torchao.dtypes.Layout`
-        dataclasses to simple dicts.
-
-        Returns:
-            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
-        """
+    def to_dict(self):
+        """Convert configuration to a dictionary."""
        d = super().to_dict()
-        if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
-            layout = d["quant_type_kwargs"]["layout"]
-            layout = dataclasses.asdict(layout)
-            d["quant_type_kwargs"]["layout"] = layout
+
+        if isinstance(self.quant_type, str):
+            # Handle layout serialization if present
+            if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
+                d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
+        else:
+            # Handle AOBaseConfig serialization
+            from torchao.core.config import config_to_dict
+
+            # For now we assume there is 1 config per Transformer; in the future
+            # we may want to support a config per fqn.
+            d["quant_type"] = {"default": config_to_dict(self.quant_type)}

        return d

+    @classmethod
+    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
+        """Create configuration from a dictionary."""
+        ao_version = cls._get_ao_version()
+        assert ao_version >= version.parse(
+            "0.10.0"
+        ), "TorchAoConfig requires torchao >= 0.10.0 for construction from dict"
+        config_dict = config_dict.copy()
+        quant_type = config_dict.pop("quant_type")
+        # Check if we only have one key, which is "default"
+        # In the future we may update this
+        assert (
+            len(quant_type) == 1 and "default" in quant_type
+        ), "Expected only one key 'default' in quant_type dictionary"
+        quant_type = quant_type["default"]
+
+        # Deserialize quant_type if needed
+        from torchao.core.config import config_from_dict
+
+        quant_type = config_from_dict(quant_type)
+
+        return cls(quant_type=quant_type, **config_dict)
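
A hedged sketch of the round-trip these two methods enable for `AOBaseConfig`-based configs (requires torchao >= 0.10.0, which `from_dict` asserts):

```py
from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig

config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))

# to_dict() stores the AOBaseConfig under quant_type["default"]
# via torchao's config_to_dict
d = config.to_dict()

# from_dict() reverses this with config_from_dict
restored = TorchAoConfig.from_dict(d)
assert isinstance(restored.quant_type, Int4WeightOnlyConfig)
```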


@dataclass
class BitNetConfig(QuantizationConfigMixin):

@@ -85,7 +85,7 @@ class TorchAoConfigTest(unittest.TestCase):
        Test kwargs validations in TorchAoConfig
        """
        _ = TorchAoConfig("int4_weight_only")
-        with self.assertRaisesRegex(ValueError, "is not supported yet"):
+        with self.assertRaisesRegex(ValueError, "Unsupported string quantization type"):
            _ = TorchAoConfig("fp6")

        with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
@@ -408,5 +408,41 @@ class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
    device = "cuda:0"


+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.10.0")
+class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+    def setUp(self):
+        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
+            raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
+
+        from torchao.quantization import Float8WeightOnlyConfig
+
+        self.quant_scheme = Float8WeightOnlyConfig()
+        self.quant_scheme_kwargs = {}
+        super().setUp()
+
+
+@require_torch_gpu
+@require_torchao_version_greater_or_equal("0.10.0")
+class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
+    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
+    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
+    device = "cuda:0"
+
+    def setUp(self):
+        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
+            raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")
+
+        from torchao.quantization import Int8DynamicActivationInt4WeightConfig
+
+        self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
+        self.quant_scheme_kwargs = {}
+        super().setUp()
+
+
if __name__ == "__main__":
    unittest.main()