Mirror of https://github.com/huggingface/transformers.git

Commit e8d960329e (parent fef8b7f8e9): Add option for ao base configs (#36526)
@@ -20,18 +20,95 @@ Install torchao with the following command.
pip install --upgrade torch torchao transformers
```
torchao supports many quantization types for different data types (int4, float8, weight only, etc.).

Starting with version 0.10.0, torchao provides enhanced flexibility through the `AOBaseConfig` API, allowing for more customized quantization configurations and full access to the techniques offered in the torchao library.

You can manually choose the quantization types and settings or automatically select the quantization types.
<hfoptions id="torchao">
<hfoption id="manual">
Create a [`TorchAoConfig`] and specify the quantization type and `group_size` of the weights to quantize. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
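For reference, a minimal sketch of that CPU path (assuming torchao 0.8.0+; the model choice is illustrative):

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.dtypes import Int4CPULayout

# int4 weight-only on CPU, per the tip above (requires torchao 0.8.0+)
quantization_config = TorchAoConfig("int4_weight_only", group_size=128, layout=Int4CPULayout())
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="cpu",
    quantization_config=quantization_config,
)
```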
In torchao 0.10.0+, you can use the more flexible `AOBaseConfig` approach instead of string identifiers:

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
from torchao.quantization import Int4WeightOnlyConfig

# Using AOBaseConfig instance (torchao >= 0.10.0)
quant_config = Int4WeightOnlyConfig(group_size=128)
quantization_config = TorchAoConfig(quant_type=quant_config)

# Load and quantize the model
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config
)

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_text = "What are we having for dinner?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# auto-compile the quantized model with `cache_implementation="static"` to get speed up
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(output[0], skip_special_tokens=True))
```
## Available Quantization Schemes

TorchAO provides a variety of quantization configurations:

- `Int4WeightOnlyConfig`
- `Int8WeightOnlyConfig`
- `Int8DynamicActivationInt8WeightConfig`
- `Float8WeightOnlyConfig`

Each configuration can be further customized with parameters such as `group_size`, `scheme`, and `layout` to optimize for specific hardware and model architectures, as sketched below.

For a complete list of available configurations, see our [quantization API documentation](https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py).
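As a quick sketch of that customization (config class names per the torchao quantization API; the `group_size` value is illustrative):

```py
from transformers import TorchAoConfig
from torchao.quantization import Int4WeightOnlyConfig, Int8DynamicActivationInt8WeightConfig

# int4 weight-only with a smaller group size
int4_config = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=64))

# int8 dynamic activation + int8 weight with defaults
a8w8_config = TorchAoConfig(quant_type=Int8DynamicActivationInt8WeightConfig())
```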
> **⚠️ DEPRECATION WARNING**
>
> Starting with version 0.10.0, the string-based API for quantization configuration (e.g., `TorchAoConfig("int4_weight_only", group_size=128)`) is **deprecated** and will be removed in a future release.
>
> Please use the new `AOBaseConfig`-based approach instead:
>
> ```python
> # Old way (deprecated)
> quantization_config = TorchAoConfig("int4_weight_only", group_size=128)
>
> # New way (recommended)
> from torchao.quantization import Int4WeightOnlyConfig
> quant_config = Int4WeightOnlyConfig(group_size=128)
> quantization_config = TorchAoConfig(quant_type=quant_config)
> ```
>
> The new API offers greater flexibility, better type safety, and access to the full range of features available in torchao.
>
> ## Migration Guide
>
> Here's how to migrate from common string identifiers to their `AOBaseConfig` equivalents:
>
> | Old String API | New `AOBaseConfig` API |
> |----------------|------------------------|
> | `"int4_weight_only"` | `Int4WeightOnlyConfig()` |
> | `"int8_weight_only"` | `Int8WeightOnlyConfig()` |
> | `"int8_dynamic_activation_int8_weight"` | `Int8DynamicActivationInt8WeightConfig()` |
>
> All configuration objects accept parameters for customization (e.g., `group_size`, `scheme`, `layout`).

Below is the API for torchao < `0.9.0`.
```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer
@@ -78,7 +155,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
The [autoquant](https://pytorch.org/ao/stable/generated/torchao.quantization.autoquant.html#torchao.quantization.autoquant) API automatically chooses a quantization type for quantizable layers (`nn.Linear`) by micro-benchmarking on input type and shape and compiling a single linear layer.

Create a [`TorchAoConfig`] and set the quantization type to `"autoquant"`. Set the `cache_implementation` to `"static"` to automatically [torch.compile](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html) the forward method. Finally, call `finalize_autoquant` on the quantized model to finalize the quantization and log the input shapes.

> [!TIP]
> Run the quantized model on a CPU by changing `device_map` to `"cpu"` and `layout` to `Int4CPULayout()`. This is only available in torchao 0.8.0+.
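A minimal sketch of that autoquant flow (model and prompt are illustrative):

```py
import torch
from transformers import TorchAoConfig, AutoModelForCausalLM, AutoTokenizer

quantization_config = TorchAoConfig("autoquant")
quantized_model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3-8B",
    torch_dtype="auto",
    device_map="auto",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
input_ids = tokenizer("What are we having for dinner?", return_tensors="pt").to("cuda")

# generation micro-benchmarks the quantizable layers; finalize afterwards
output = quantized_model.generate(**input_ids, max_new_tokens=10, cache_implementation="static")
if hasattr(quantized_model, "finalize_autoquant"):
    quantized_model.finalize_autoquant()
print(tokenizer.decode(output[0], skip_special_tokens=True))
```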
@@ -131,7 +208,7 @@ print("bf16 model:", benchmark_fn(bf16_model.generate, **input_ids, max_new_toke
## Serialization

torchao implements [torch.Tensor subclasses](https://pytorch.org/docs/stable/notes/extending.html#subclassing-torch-tensor) for maximum flexibility in supporting new quantized torch.Tensor formats. [Safetensors](https://huggingface.co/docs/safetensors/en/index) serialization and deserialization does not work with torchao.

To avoid arbitrary user code execution, torchao sets `weights_only=True` in [torch.load](https://pytorch.org/docs/stable/generated/torch.load.html) to ensure only tensors are loaded. Any known user functions can be whitelisted with [add_safe_globals](https://pytorch.org/docs/stable/notes/serialization.html#torch.serialization.add_safe_globals).
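A sketch of saving and reloading under those constraints (output directory is illustrative; note `safe_serialization=False`, since Safetensors is unsupported):

```py
from transformers import AutoModelForCausalLM

# save with plain torch serialization rather than Safetensors
quantized_model.save_pretrained("./llama-3-8b-int4", safe_serialization=False)

# weights are then restored through torch.load(..., weights_only=True)
reloaded_model = AutoModelForCausalLM.from_pretrained(
    "./llama-3-8b-int4",
    device_map="auto",
    torch_dtype="auto",
)
```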
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import re
import types
from typing import TYPE_CHECKING, Optional, Union
@@ -27,6 +28,7 @@ if TYPE_CHECKING:
    from typing import Any, Dict, List

from ..utils import is_torch_available, is_torchao_available, logging
from ..utils.quantization_config import TorchAoConfig


if is_torch_available():
@@ -36,6 +38,21 @@ if is_torch_available():
logger = logging.get_logger(__name__)


def fuzzy_match_size(config_name: str) -> Optional[str]:
    """
    Extract the size digit from strings like "4weight" or "8weight".
    Returns the digit as a string if found, otherwise None.
    """
    config_name = config_name.lower()

    str_match = re.search(r"(\d)weight", config_name)

    if str_match:
        return str_match.group(1)

    return None
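For instance, applied to the class names of `AOBaseConfig` subclasses (hypothetical inputs):

```py
fuzzy_match_size("Int4WeightOnlyConfig")    # -> "4"
fuzzy_match_size("Float8WeightOnlyConfig")  # -> "8"
fuzzy_match_size("SomeOtherConfig")         # -> None
```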
# Finds the parent of a node module named "name"
def find_parent(model, name):
    module_tree = name.split(".")[:-1]
@@ -121,10 +138,28 @@ class TorchAoHfQuantizer(HfQuantizer):
            torch_dtype = torch.float32
        return torch_dtype

    def adjust_target_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
            from accelerate.utils import CustomDtype

            # Import AOBaseConfig directly since we know we have the right version
            if self.quantization_config._get_ao_version() >= version.Version("0.10.0"):
                from torchao.core.config import AOBaseConfig

                quant_type = self.quantization_config.quant_type
                if isinstance(quant_type, AOBaseConfig):
                    # Extract size digit using fuzzy match on the class name
                    config_name = quant_type.__class__.__name__
                    size_digit = fuzzy_match_size(config_name)

                    # Map the extracted digit to appropriate dtype
                    if size_digit == "4":
                        return CustomDtype.INT4
                    else:
                        # Default to int8
                        return torch.int8

            # Original mapping for non-AOBaseConfig types
            map_to_target_dtype = {
                "int4_weight_only": CustomDtype.INT4,
                "int8_weight_only": torch.int8,
@@ -194,14 +229,14 @@ class TorchAoHfQuantizer(HfQuantizer):
        from torchao.quantization import quantize_

        module, tensor_name = get_module_from_name(model, param_name)

        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
            if isinstance(module, nn.Linear):
                module.extra_repr = types.MethodType(_linear_extra_repr, module)
        else:
            assert isinstance(self.quantization_config, TorchAoConfig)
            module._parameters[tensor_name] = torch.nn.Parameter(param_value).to(device=target_device)
            quantize_(module, self.quantization_config.get_quantize_config())

    def _process_model_after_weight_loading(self, model, **kwargs):
        """No process required for torchao quantized model"""
@@ -216,7 +251,7 @@ class TorchAoHfQuantizer(HfQuantizer):
            return model
        return

    def is_serializable(self, safe_serialization=None) -> bool:
        if safe_serialization:
            logger.warning(
                "torchao quantized model does not support safe serialization, "
@@ -237,7 +272,7 @@ class TorchAoHfQuantizer(HfQuantizer):
        return _is_torchao_serializable

    @property
    def is_trainable(self) -> bool:
        supported_quant_types_for_training = [
            "int8_weight_only",
            "int8_dynamic_activation_int8_weight",
@@ -95,6 +95,7 @@ GGUF_MIN_VERSION = "0.10.0"
XLA_FSDPV2_MIN_VERSION = "2.2.0"
HQQ_MIN_VERSION = "0.2.1"
VPTQ_MIN_VERSION = "0.0.4"
TORCHAO_MIN_VERSION = "0.4.0"


_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -191,7 +192,7 @@ _tf2onnx_available = _is_package_available("tf2onnx")
_timm_available = _is_package_available("timm")
_tokenizers_available = _is_package_available("tokenizers")
_torchaudio_available = _is_package_available("torchaudio")
_torchao_available, _torchao_version = _is_package_available("torchao", return_version=True)
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available, _torchvision_version = _is_package_available("torchvision", return_version=True)
_mlx_available = _is_package_available("mlx")
@@ -1277,8 +1278,8 @@ def is_torchaudio_available():
    return _torchaudio_available


def is_torchao_available(min_version: str = TORCHAO_MIN_VERSION):
    return _torchao_available and version.parse(_torchao_version) >= version.parse(min_version)
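Callers can now gate features on a minimum torchao version; a hypothetical call site:

```py
# only take the AOBaseConfig path when torchao >= 0.10.0 is installed
if is_torchao_available("0.10.0"):
    from torchao.core.config import AOBaseConfig
```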

def is_speech_available():
@@ -1455,11 +1455,18 @@ class HiggsConfig(QuantizationConfigMixin):

@dataclass
class TorchAoConfig(QuantizationConfigMixin):
    quant_method: QuantizationMethod
    quant_type: Union[str, "AOBaseConfig"]  # noqa: F821
    modules_to_not_convert: Optional[List]
    quant_type_kwargs: Dict[str, Any]

    """This is a config class for torchao quantization/sparsity techniques.

    Args:
        quant_type (`Union[str, AOBaseConfig]`):
            The type of quantization we want to use. Can be either:
            - A string: currently supporting `int4_weight_only`, `int8_weight_only` and `int8_dynamic_activation_int8_weight`.
            - An `AOBaseConfig` instance: for more advanced configuration options.
        modules_to_not_convert (`list`, *optional*, default to `None`):
            The list of modules to not quantize, useful for quantizing models that explicitly require to have
            some modules left in their original precision.
@@ -1471,9 +1478,12 @@ class TorchAoConfig(QuantizationConfigMixin):
    Example:

    ```python
    from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
    from torchao.quantization import Int4WeightOnlyConfig

    # AOBaseConfig-based configuration
    config = Int4WeightOnlyConfig(group_size=32)
    quantization_config = TorchAoConfig(config)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)

    # String-based configuration
    quantization_config = TorchAoConfig("int4_weight_only", group_size=32)
    # int4_weight_only quant is only working with *torch.bfloat16* dtype right now
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
@@ -1496,105 +1506,152 @@ class TorchAoConfig(QuantizationConfigMixin):
    if hasattr(quantized_model, "finalize_autoquant"):
        print("finalizing autoquant")
        quantized_model.finalize_autoquant()

    ```
    """
    def __init__(
        self,
        quant_type: Union[str, "AOBaseConfig"],  # noqa: F821
        modules_to_not_convert: Optional[List] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.TORCHAO
        self.quant_type = quant_type
        self.modules_to_not_convert = modules_to_not_convert
        # when we load from a serialized config, "quant_type_kwargs" will be the key
        self.quant_type_kwargs = kwargs.get("quant_type_kwargs", kwargs)

        self.post_init()
    @staticmethod
    def _get_ao_version() -> version.Version:
        """Centralized check for TorchAO availability and version requirements."""
        if not is_torchao_available():
            raise ValueError("TorchAoConfig requires torchao to be installed. Install with `pip install torchao`")

        return version.parse(importlib.metadata.version("torchao"))
    def post_init(self):
        """Validate configuration and set defaults."""
        ao_version = self._get_ao_version()

        # Handle quant_type based on type and version
        if isinstance(self.quant_type, str):
            self._validate_string_quant_type()
        elif ao_version >= version.parse("0.10.0"):
            from torchao.quantization.quant_api import AOBaseConfig

            if not isinstance(self.quant_type, AOBaseConfig):
                raise ValueError(
                    f"quant_type must be either a string or an AOBaseConfig instance, got {type(self.quant_type)}"
                )
        else:
            raise ValueError(
                f"In torchao < 0.10.0, quant_type must be a string. Got {type(self.quant_type)}. "
                f"Please upgrade to torchao >= 0.10.0 to use AOBaseConfig instances."
            )
    def _validate_string_quant_type(self):
        """Validate string quant_type and its kwargs."""
        methods = self._get_torchao_quant_type_to_method()

        if self.quant_type not in methods:
            raise ValueError(
                f"Unsupported string quantization type: {self.quant_type}. "
                f"Supported types: {', '.join(methods.keys())}"
            )

        # Validate kwargs against method signature
        method = methods[self.quant_type]
        sig = signature(method)
        valid_kwargs = {
            param.name
            for param in sig.parameters.values()
            if param.kind in [Parameter.KEYWORD_ONLY, Parameter.POSITIONAL_OR_KEYWORD]
        }

        invalid_kwargs = set(self.quant_type_kwargs) - valid_kwargs
        if invalid_kwargs:
            raise ValueError(
                f"Unexpected keyword arg for {self.quant_type}: {', '.join(invalid_kwargs)}. "
                f"Valid kwargs: {', '.join(valid_kwargs)}"
            )
    def _get_torchao_quant_type_to_method(self):
        """Get mapping of quant_type strings to their corresponding methods."""
        from torchao.quantization import (
            autoquant,
            int4_weight_only,
            int8_dynamic_activation_int8_weight,
            int8_weight_only,
        )

        return {
            "int4_weight_only": int4_weight_only,
            "int8_weight_only": int8_weight_only,
            "int8_dynamic_activation_int8_weight": int8_dynamic_activation_int8_weight,
            "autoquant": autoquant,
        }
    def get_quantize_config(self):
        """Create the appropriate quantization method based on configuration."""
        if isinstance(self.quant_type, str):
            methods = self._get_torchao_quant_type_to_method()
            quant_type_kwargs = self.quant_type_kwargs.copy()
            if (
                not torch.cuda.is_available()
                and is_torchao_available()
                and self.quant_type == "int4_weight_only"
                and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
            ):
                from torchao.dtypes import Int4CPULayout

                quant_type_kwargs["layout"] = Int4CPULayout()

            return methods[self.quant_type](**quant_type_kwargs)
        else:
            return self.quant_type
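A sketch of how a config resolves through this method, mirroring the quantizer's `quantize_(module, ...)` call above (the `model` module is assumed):

```py
from torchao.quantization import quantize_

config = TorchAoConfig("int4_weight_only", group_size=128)
# for string types this returns int4_weight_only(group_size=128);
# for an AOBaseConfig instance it is returned unchanged
quantize_(model, config.get_quantize_config())
```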
    def to_dict(self):
        """Convert configuration to a dictionary."""
        d = super().to_dict()

        if isinstance(self.quant_type, str):
            # Handle layout serialization if present
            if "quant_type_kwargs" in d and "layout" in d["quant_type_kwargs"]:
                d["quant_type_kwargs"]["layout"] = dataclasses.asdict(d["quant_type_kwargs"]["layout"])
        else:
            # Handle AOBaseConfig serialization
            from torchao.core.config import config_to_dict

            # For now we assume there is one config per Transformer; in the future
            # we may want to support a config per fqn.
            d["quant_type"] = {"default": config_to_dict(self.quant_type)}

        return d
    @classmethod
    def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):
        """Create configuration from a dictionary."""
        ao_version = cls._get_ao_version()
        assert ao_version >= version.parse(
            "0.10.0"
        ), "TorchAoConfig requires torchao >= 0.10.0 for construction from dict"
        config_dict = config_dict.copy()
        quant_type = config_dict.pop("quant_type")
        # Check that we only have one key, which is "default"
        # In the future we may update this
        assert (
            len(quant_type) == 1 and "default" in quant_type
        ), "Expected only one key 'default' in quant_type dictionary"
        quant_type = quant_type["default"]

        # Deserialize quant_type if needed
        from torchao.core.config import config_from_dict

        quant_type = config_from_dict(quant_type)

        return cls(quant_type=quant_type, **config_dict)
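A round-trip sketch of this serialization pair (assuming torchao >= 0.10.0):

```py
from torchao.quantization import Int4WeightOnlyConfig

cfg = TorchAoConfig(quant_type=Int4WeightOnlyConfig(group_size=128))
d = cfg.to_dict()  # quant_type is serialized as {"default": config_to_dict(...)}
restored = TorchAoConfig.from_dict(d)  # rebuilt via config_from_dict
```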

@dataclass
class BitNetConfig(QuantizationConfigMixin):
@@ -85,7 +85,7 @@ class TorchAoConfigTest(unittest.TestCase):
        Test kwargs validations in TorchAoConfig
        """
        _ = TorchAoConfig("int4_weight_only")
        with self.assertRaisesRegex(ValueError, "Unsupported string quantization type"):
            _ = TorchAoConfig("fp6")

        with self.assertRaisesRegex(ValueError, "Unexpected keyword arg"):
@@ -408,5 +408,41 @@ class TorchAoSerializationW8GPUTest(TorchAoSerializationTest):
    device = "cuda:0"


@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationFP8GPUTest(TorchAoSerializationTest):
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
    device = "cuda:0"

    def setUp(self):
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for FP8 tests")

        from torchao.quantization import Float8WeightOnlyConfig

        self.quant_scheme = Float8WeightOnlyConfig()
        self.quant_scheme_kwargs = {}
        super().setUp()


@require_torch_gpu
@require_torchao_version_greater_or_equal("0.10.0")
class TorchAoSerializationA8W4Test(TorchAoSerializationTest):
    ORIGINAL_EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
    SERIALIZED_EXPECTED_OUTPUT = ORIGINAL_EXPECTED_OUTPUT
    device = "cuda:0"

    def setUp(self):
        if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 9:
            raise unittest.SkipTest("CUDA compute capability 9.0 or higher required for this test")

        from torchao.quantization import Int8DynamicActivationInt4WeightConfig

        self.quant_scheme = Int8DynamicActivationInt4WeightConfig()
        self.quant_scheme_kwargs = {}
        super().setUp()


if __name__ == "__main__":
    unittest.main()