Add HQQ quantization support (#29637)

* update HQQ transformers integration

* push import_utils.py

* add force_hooks check in modeling_utils.py

* fix | with Optional

* force bias as param

* check bias is Tensor

* force forward for multi-gpu

* review fixes pass

* remove torch grad()

* if any key in linear_tags fix

* add cpu/disk check

* isinstance return

* add multigpu test + refactor tests

* clean hqq_utils imports in hqq.py

* clean hqq_utils imports in quantizer_hqq.py

* delete hqq_utils.py

* Delete src/transformers/utils/hqq_utils.py

* ruff init

* remove torch.float16 from __init__ in test

* refactor test

* isinstance -> type in quantizer_hqq.py

* cpu/disk device_map check in quantizer_hqq.py

* remove type(module) nn.linear check in quantizer_hqq.py

* add BaseQuantizeConfig import inside HqqConfig init

* remove hqq import in hqq.py

* remove accelerate import from test_hqq.py

* quant config.py doc update

* add hqqconfig to main_classes doc

* make style

* __init__ fix

* ruff __init__

* skip_modules list

* hqqconfig format fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* hqqconfig doc fix

* test_hqq.py remove mistral comment

* remove self.using_multi_gpu is False

* torch_dtype default val set and logger.info

* hqq.py isinstance fix

* remove torch=None

* torch_device test_hqq

* rename test_hqq

* MODEL_ID in test_hqq

* quantizer_hqq setattr fix

* quantizer_hqq typo fix

* imports quantizer_hqq.py

* isinstance quantizer_hqq

* hqq_layer.bias reformat quantizer_hqq

* Step 2 as comment in quantizer_hqq

* prepare_for_hqq_linear() comment

* keep_in_fp32_modules fix

* HqqHfQuantizer reformat

* quantization.md hqqconfig

* quantization.md model example reformat

* quantization.md # space

* quantization.md space   })

* quantization.md space   })

* quantization_config fix doc

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* axis value check in quantization_config

* format

* dynamic config explanation

* quant config method in quantization.md

* remove shard-level progress

* .cuda fix modeling_utils

* test_hqq fixes

* make fix-copies

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Author: mobicham, 2024-05-02 18:51:49 +02:00 (committed by GitHub)
commit 59952994c4, parent 4c940934da
16 changed files with 681 additions and 1 deletion

docker/transformers-quantization-latest-gpu/Dockerfile (3 changes, Normal file → Executable file)

@@ -45,6 +45,9 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add hqq for quantization testing
RUN python3 -m pip install --no-cache-dir hqq
# Add autoawq for quantization testing
# >=v0.2.3 needed for compatibility with torch 2.2.1
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.3/autoawq-0.2.3+cu118-cp38-cp38-linux_x86_64.whl

docs/source/en/main_classes/quantization.md (4 changes, Normal file → Executable file)

@@ -52,3 +52,7 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
## HfQuantizer
[[autodoc]] quantizers.base.HfQuantizer
## HqqConfig
[[autodoc]] HqqConfig

docs/source/en/quantization.md (50 changes, Normal file → Executable file)

@@ -745,3 +745,53 @@ The speed and throughput of fused and unfused modules were also tested with the
<figcaption class="mt-2 text-center text-sm text-gray-500">generate throughput/batch size</figcaption>
</div>
</div>
## HQQ
Half-Quadratic Quantization (HQQ) implements on-the-fly quantization via fast robust optimization. It doesn't require calibration data and can be used to quantize any model.
Please refer to the <a href="https://github.com/mobiusml/hqq/">official package</a> for more details.
For installation, we recommend you use the following approach to get the latest version and build its corresponding CUDA kernels:
```bash
pip install hqq
```
To quantize a model, you need to create an [`HqqConfig`]. There are two ways of doing it:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
# Method 1: all linear layers will use the same quantization config
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)  # axis=0 is used by default
```
```python
# Method 2: each linear layer with the same tag will use a dedicated quantization config
q4_config = {'nbits': 4, 'group_size': 64, 'quant_zero': False, 'quant_scale': False}
q3_config = {'nbits': 3, 'group_size': 32, 'quant_zero': False, 'quant_scale': False}
quant_config = HqqConfig(dynamic_config={
    'self_attn.q_proj': q4_config,
    'self_attn.k_proj': q4_config,
    'self_attn.v_proj': q4_config,
    'self_attn.o_proj': q4_config,
    'mlp.gate_proj': q3_config,
    'mlp.up_proj': q3_config,
    'mlp.down_proj': q3_config,
})
```
The second approach is especially interesting for quantizing Mixture-of-Experts (MoEs) because the experts are less affected by lower quantization settings.
Then you simply quantize the model as follows:
```python
import torch

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    device_map="cuda",
    quantization_config=quant_config,
)
```
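For a quick end-to-end check (a minimal sketch, not part of this diff; `model_id` is the same placeholder checkpoint id as above and the prompt is arbitrary):
```python
from transformers import AutoTokenizer

# Hypothetical smoke test: generate a few tokens with the HQQ-quantized model loaded above.
tokenizer = AutoTokenizer.from_pretrained(model_id)
inputs = tokenizer("Half-Quadratic Quantization is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```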
### Optimized Runtime
HQQ supports various backends, including pure PyTorch and custom dequantization CUDA kernels. These backends are suitable for older GPUs and PEFT/QLoRA training.
For faster inference, HQQ supports 4-bit fused kernels (TorchAO and Marlin), reaching up to 200 tokens/sec on a single 4090.
For more details on how to use the backends, please refer to https://github.com/mobiusml/hqq/?tab=readme-ov-file#backend
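As a concrete illustration (a sketch, not from the diff itself; it mirrors the call used in the new test file below), the backend is selected on the `HQQLinear` class:
```python
from hqq.core.quantize import HQQBackend, HQQLinear

# Select the dequantization backend globally for all HQQLinear layers.
# HQQBackend.PYTORCH is the portable pure-PyTorch path; see the hqq README for fused-kernel options.
HQQLinear.set_backend(HQQBackend.PYTORCH)
```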

docs/source/en/quicktour.md (0 changes, Normal file → Executable file)

src/transformers/__init__.py (2 changes, Normal file → Executable file)

@@ -1133,6 +1133,7 @@ _import_structure = {
"BitsAndBytesConfig",
"EetqConfig",
"GPTQConfig",
"HqqConfig",
"QuantoConfig",
],
}
@@ -6099,6 +6100,7 @@ if TYPE_CHECKING:
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
QuantoConfig,
)

src/transformers/integrations/__init__.py (2 changes, Normal file → Executable file)

@@ -43,6 +43,7 @@ _import_structure = {
"unset_hf_deepspeed_config",
],
"eetq": ["replace_with_eetq_linear"],
"hqq": ["prepare_for_hqq_linear"],
"integration_utils": [
"INTEGRATION_TO_CALLBACK",
"AzureMLCallback",
@@ -113,6 +114,7 @@ if TYPE_CHECKING:
unset_hf_deepspeed_config,
)
from .eetq import replace_with_eetq_linear
from .hqq import prepare_for_hqq_linear
from .integration_utils import (
INTEGRATION_TO_CALLBACK,
AzureMLCallback,

src/transformers/integrations/hqq.py (new file)

@@ -0,0 +1,121 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"HQQ (Half-Quadratic Quantization) integration file"
from ..utils import is_hqq_available, is_torch_available, logging
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
# Name all modules inside the model
def autoname_modules(model):
for name, module in model.named_modules():
module.name = name
# Get the linear_tag from a module name. For example: model.layers.31.self_attn.k_proj -> self_attn.k_proj
def name_to_linear_tag(name):
return ".".join([n for n in name.split(".") if ((n not in ["model", "layers"]) and (not n.isnumeric()))])
# Get all linear tags available
def get_linear_tags(model):
if is_hqq_available():
from hqq.core.quantize import HQQLinear
linear_tags = set()
for name, module in model.named_modules():
if isinstance(module, (torch.nn.Linear, HQQLinear)):
linear_tags.add(name_to_linear_tag(name))
return list(linear_tags)
def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_name=None):
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
current_key_name.append(name)
if isinstance(module, torch.nn.Linear):
# Get linear tag
linear_tag = name_to_linear_tag(module.name)
# We put the module quant_config into the nn.Linear layer so we can access it later in quantizer_hqq.create_quantized_param()
if linear_tag in patch_params:
if patch_params[linear_tag] is not None:
model._modules[name].quant_config = patch_params[linear_tag]
# Store the module class in case we need to transpose the weight later
model._modules[name].source_cls = type(module)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
has_been_replaced = True
if len(list(module.children())) > 0:
_, has_been_replaced = _prepare_for_hqq_linear(
module,
patch_params=patch_params,
has_been_replaced=has_been_replaced,
)
# Remove the last key for recursion
current_key_name.pop(-1)
return model, has_been_replaced
def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_convert=None, has_been_replaced=False):
"""
Prepares nn.Linear layers for HQQ quantization.
Since each layer type can have separate quantization parameters, we need to do the following:
1- Tag each module with its name via autoname_modules()
2- Extract linear_tags (e.g. ['self_attn.q_proj', ...])
3- Map quantization parameters as a dictionary, linear_tag -> quant_params, as HQQLinear expects it; this is referred to as patch_params
"""
modules_to_not_convert = [] if modules_to_not_convert is None else modules_to_not_convert
# Add name to module
autoname_modules(model)
# Get linear tags. This allows us to use different quant params for different layer types
linear_tags = get_linear_tags(model)
# Convert quantization_config to layer-wise config
skip_modules = quantization_config.skip_modules
quant_config = quantization_config.to_dict()
linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
if any(key in linear_tags for key in quant_config.keys()):
# If the user doesn't specify a quant_config for a given linear tag, that layer is mapped to None and left unquantized
patch_params = {key: None for key in linear_tags}
patch_params.update(quant_config)
else:
# Same quant_config for all layers
patch_params = {k: quant_config for k in linear_tags}
model, has_been_replaced = _prepare_for_hqq_linear(
model, patch_params=patch_params, has_been_replaced=has_been_replaced
)
# We store quantization config as linear_tag -> hqq quant config
model.config.quantization_config = patch_params
if not has_been_replaced:
logger.warning("No linear modules were found in your model for quantization.")
return model
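To make the linear-tag mapping used by `prepare_for_hqq_linear()` concrete, a minimal sketch (not part of the diff; the module path is hypothetical):
```python
# name_to_linear_tag() drops "model", "layers" and numeric indices from a module path,
# so layers playing the same role share one tag and therefore one quant config.
name = "model.layers.31.self_attn.k_proj"
tag = ".".join(n for n in name.split(".") if n not in ("model", "layers") and not n.isnumeric())
print(tag)  # self_attn.k_proj
```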

src/transformers/integrations/integration_utils.py (0 changes, Normal file → Executable file)

src/transformers/modeling_utils.py (11 changes, Normal file → Executable file)

@@ -2659,6 +2659,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@wraps(torch.nn.Module.cuda)
def cuda(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.cuda` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
@@ -2670,6 +2672,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
@wraps(torch.nn.Module.to)
def to(self, *args, **kwargs):
if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
raise ValueError("`.to` is not supported for HQQ-quantized models.")
# Checks if the model has been loaded in 8-bit
if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
raise ValueError(
@@ -3739,6 +3743,13 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
}
if "skip_keys" in inspect.signature(dispatch_model).parameters:
device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
# For HQQ method we force-set the hooks for single GPU envs
if (
"force_hooks" in inspect.signature(dispatch_model).parameters
and hf_quantizer is not None
and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
):
device_map_kwargs["force_hooks"] = True
if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
dispatch_model(model, **device_map_kwargs)
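A small illustration of the `.to`/`.cuda` guards added above (a sketch, not part of the diff; `model` is assumed to be an HQQ-quantized `PreTrainedModel`):
```python
# The new guards reject device moves for HQQ-quantized models and raise a ValueError.
try:
    model.to("cpu")
except ValueError as err:
    print(err)  # `.to` is not supported for HQQ-quantized models.
```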

src/transformers/quantizers/__init__.py (0 changes, Normal file → Executable file)

src/transformers/quantizers/auto.py (4 changes, Normal file → Executable file)

@@ -21,6 +21,7 @@ from ..utils.quantization_config import (
BitsAndBytesConfig,
EetqConfig,
GPTQConfig,
HqqConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
@@ -31,6 +32,7 @@ from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
from .quantizer_eetq import EetqHfQuantizer
from .quantizer_gptq import GptqHfQuantizer
from .quantizer_hqq import HqqHfQuantizer
from .quantizer_quanto import QuantoHfQuantizer
@@ -42,6 +44,7 @@ AUTO_QUANTIZER_MAPPING = {
"aqlm": AqlmHfQuantizer,
"quanto": QuantoHfQuantizer,
"eetq": EetqHfQuantizer,
"hqq": HqqHfQuantizer,
}
AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -52,6 +55,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
"gptq": GPTQConfig,
"aqlm": AqlmConfig,
"quanto": QuantoConfig,
"hqq": HqqConfig,
}

src/transformers/quantizers/quantizer_hqq.py (new file)

@@ -0,0 +1,200 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Any, Dict, List
from ..integrations import prepare_for_hqq_linear
from ..utils import is_accelerate_available, is_hqq_available, is_torch_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
if is_accelerate_available():
from accelerate.hooks import remove_hook_from_module
if is_torch_available():
import torch
logger = logging.get_logger(__name__)
# Finds the parent of a node module named "name"
def find_parent(model, name):
module_tree = name.split(".")[:-1]
parent = model
for m in module_tree:
parent = parent._modules[m]
return parent
class HqqHfQuantizer(HfQuantizer):
"""
HQQ quantizer base HF class.
nn.Linear modules are first tagged with quant_config in _process_model_before_weight_loading().
The actual quantization and offloading to the GPU are done in create_quantized_param().
"""
use_keep_in_fp32_modules = False
requires_parameters_quantization = True
requires_calibration = False
required_packages = ["hqq"]
def __init__(self, quantization_config, **kwargs):
super().__init__(quantization_config, **kwargs)
self.torch_dtype = None
self.using_multi_gpu = False
def validate_environment(self, *args, **kwargs):
if not (is_hqq_available()):
raise ImportError(
"HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
)
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
raise ValueError(
"Converting weights from tf/flax weights is currently not supported, please make"
" sure the weights are in PyTorch format."
)
if not torch.cuda.is_available():
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
if self.torch_dtype is None:
if "torch_dtype" in kwargs:
self.torch_dtype = kwargs["torch_dtype"]
else:
self.torch_dtype = torch.float32
logger.info("Setting torch_dtype to torch.float32 as the default value since it was not specified.")
device_map = kwargs.get("device_map", None)
if isinstance(device_map, dict):
if "cpu" in device_map.values() or "disk" in device_map.values():
raise ValueError(
"You are attempting to use an HQQ model with a device_map that contains a CPU or disk device."
" This is not supported. Please remove the CPU or disk device from the device_map."
)
else:
self.using_multi_gpu = len(set(device_map.values())) > 1
def check_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
state_dict: Dict[str, Any],
**kwargs,
) -> bool:
module, tensor_name = get_module_from_name(model, param_name)
return isinstance(module, torch.nn.Linear)
def create_quantized_param(
self,
model: "PreTrainedModel",
param_value: "torch.Tensor",
param_name: str,
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
):
"""
Each nn.Linear layer is processed here.
We first check if the corresponding module state_dict already contains HQQ-quantized parameters.
If not, we create a temp linear layer with the module state_dict params and use it for quantization.
"""
if is_hqq_available():
from hqq.core.quantize import HQQLinear
module, tensor_name = get_module_from_name(model, param_name)
layer_name = param_name.replace(".weight", "").replace(".bias", "")
parent_module = find_parent(model, layer_name)
node = layer_name.split(".")[-1]
# Step 0: set module state_dict
module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
# Step 1: populate module with weight/bias from module state dict
for key in module_state_dict:
setattr(module, key, torch.nn.Parameter(module_state_dict[key]))
# Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing it on the module
# directly doesn't work.
if hasattr(module, "quant_config"):
hqq_layer = HQQLinear(
module,
module.quant_config,
compute_dtype=self.torch_dtype,
device=target_device,
del_orig=True,
)
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
if self.using_multi_gpu:
hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
setattr(parent_module, node, hqq_layer)
else:
module = module.to(dtype=self.torch_dtype, device=target_device)
setattr(parent_module, node, module)
torch.cuda.empty_cache()
# Removes the accelerate hook and uses a simpler forward pass. Otherwise, this breaks with multi-gpu
def _patch_layer_for_multigpu(self, hqq_layer):
hqq_layer = remove_hook_from_module(hqq_layer)
def forward_with_device(self, x):
out = torch.matmul(x.to(self.device), self.dequantize().t())
if self.bias is not None:
out += self.bias
return out
hqq_layer.forward = lambda x: forward_with_device(hqq_layer, x)
return hqq_layer
def _process_model_before_weight_loading(
self,
model: "PreTrainedModel",
device_map,
keep_in_fp32_modules: List[str] = None,
**kwargs,
):
keep_in_fp32_modules = keep_in_fp32_modules if keep_in_fp32_modules is not None else []
# Add the corresponding quant_config to each valid module. This allows us to do the actual nn.Linear -> HQQLinear conversion in create_quantized_param().
# prepare_for_hqq_linear() also sets the right quantization config inside the model (model.config.quantization_config) and the layers (hqq_layer.quant_config)
model = prepare_for_hqq_linear(model, quantization_config=self.quantization_config)
def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
model.is_hqq_quantized = True
model.is_hqq_serializable = self.is_serializable
return model
@property
def is_serializable(self):
return False
@property
def is_trainable(self) -> bool:
return False

src/transformers/utils/__init__.py (1 change, Normal file → Executable file)

@@ -129,6 +129,7 @@ from .import_utils import (
is_ftfy_available,
is_g2p_en_available,
is_galore_torch_available,
is_hqq_available,
is_in_notebook,
is_ipex_available,
is_jieba_available,

src/transformers/utils/import_utils.py (5 changes, Normal file → Executable file)

@@ -170,6 +170,7 @@ _torchaudio_available = _is_package_available("torchaudio")
_torchdistx_available = _is_package_available("torchdistx")
_torchvision_available = _is_package_available("torchvision")
_mlx_available = _is_package_available("mlx")
_hqq_available = _is_package_available("hqq")
_torch_version = "N/A"
@@ -292,6 +293,10 @@ def is_torch_available():
return _torch_available
def is_hqq_available():
return _hqq_available
def get_torch_version():
return _torch_version

src/transformers/utils/quantization_config.py (112 changes, Normal file → Executable file)

@@ -24,7 +24,7 @@ from typing import Any, Dict, List, Optional, Union
from packaging import version
from ..utils import is_auto_awq_available, is_torch_available, logging
from ..utils import is_auto_awq_available, is_hqq_available, is_torch_available, logging
if is_torch_available():
@@ -41,6 +41,7 @@ class QuantizationMethod(str, Enum):
AQLM = "aqlm"
QUANTO = "quanto"
EETQ = "eetq"
HQQ = "hqq"
class AWQLinearVersion(str, Enum):
@@ -180,6 +181,115 @@ class QuantizationConfigMixin:
return unused_kwargs
@dataclass
class HqqConfig(QuantizationConfigMixin):
"""
This is a wrapper around hqq's BaseQuantizeConfig.
Args:
nbits (`int`, *optional*, defaults to 4):
Number of bits. Supported values are (8, 4, 3, 2, 1).
group_size (`int`, *optional*, defaults to 64):
Group-size value. Supported values are any value that is divisible by weight.shape[axis].
quant_zero (`bool`, *optional*, defaults to `True`):
Quantize the zero-point if set to `True`.
quant_scale (`bool`, *optional*, defaults to `False`):
Quantize the scaling if set to `True`.
offload_meta (`bool`, *optional*, defaults to `False`):
Offload the meta-data to the CPU if set to `True`.
view_as_float (`bool`, *optional*, defaults to `False`):
View the quantized weight as float (used in distributed training) if set to `True`.
axis (`int`, *optional*, defaults to 0):
Axis along which grouping is performed. Supported values are 0 or 1.
dynamic_config (dict, *optional*):
Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
If set, each layer specified by its id will use its dedicated quantization configuration.
skip_modules (`List[str]`, *optional*, defaults to `['lm_head']`):
List of `nn.Linear` layers to skip.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""
def __init__(
self,
nbits: int = 4,
group_size: int = 64,
quant_zero: bool = True,
quant_scale: bool = False,
offload_meta: bool = False,
view_as_float: bool = False,
axis: int = 0,
dynamic_config: Optional[dict] = None,
skip_modules: List[str] = ["lm_head"],
**kwargs,
):
if is_hqq_available():
from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
if axis not in [0, 1]:
raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
if dynamic_config is not None:
self.quant_config = {}
for key in dynamic_config:
self.quant_config[key] = HQQBaseQuantizeConfig(**dynamic_config[key])
else:
self.quant_config = HQQBaseQuantizeConfig(
**{
"nbits": nbits,
"group_size": group_size,
"quant_zero": quant_zero,
"quant_scale": quant_scale,
"offload_meta": offload_meta,
"view_as_float": view_as_float,
"axis": axis,
}
)
self.quant_method = QuantizationMethod.HQQ
self.skip_modules = skip_modules
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
pass
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
return self.quant_config
def __repr__(self):
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
"""
config_dict = self.to_dict()
# get the default config dict
default_config_dict = HqqConfig().to_dict()
serializable_config_dict = {}
# only serialize values that differ from the default config
for key, value in config_dict.items():
if value != default_config_dict[key]:
serializable_config_dict[key] = value
return serializable_config_dict
@dataclass
class BitsAndBytesConfig(QuantizationConfigMixin):
"""

tests/quantization/hqq/test_hqq.py (new file)

@@ -0,0 +1,167 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
from transformers.testing_utils import (
require_accelerate,
require_torch_gpu,
require_torch_multi_gpu,
slow,
torch_device,
)
from transformers.utils import is_hqq_available, is_torch_available
if is_torch_available():
import torch
if is_hqq_available():
from hqq.core.quantize import HQQBackend, HQQLinear
class HQQLLMRunner:
def __init__(self, model_id, quant_config, compute_dtype, device, cache_dir):
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
torch_dtype=compute_dtype,
device_map=device,
quantization_config=quant_config,
low_cpu_mem_usage=True,
cache_dir=cache_dir,
)
self.tokenizer = AutoTokenizer.from_pretrained(model_id, cache_dir=cache_dir)
self.device = self.model.device
HQQLinear.set_backend(HQQBackend.PYTORCH)
def cleanup():
torch.cuda.empty_cache()
gc.collect()
def check_hqqlayer(test_module, hqq_layer, batch_size=1, context_size=1024):
# Test HQQ layer
W_dequant = hqq_layer.dequantize() # Reconstructed weights
inputs = (
torch.randn(
(batch_size, context_size, hqq_layer.meta["shape"][1]),
device=hqq_layer.device,
dtype=hqq_layer.compute_dtype,
)
/ 10.0
)
with torch.no_grad():
outputs = hqq_layer(inputs)
test_module.assertEqual(outputs.shape[-1], W_dequant.shape[0])
test_module.assertEqual(outputs.dtype, hqq_layer.compute_dtype)
del W_dequant, inputs, outputs
cleanup()
def check_forward(test_module, model, batch_size=1, context_size=1024):
# Test forward pass
with torch.no_grad():
out = model(torch.zeros([batch_size, context_size], device=model.device, dtype=torch.int32)).logits
test_module.assertEqual(out.shape[0], batch_size)
test_module.assertEqual(out.shape[1], context_size)
cleanup()
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
@require_torch_gpu
class HqqConfigTest(unittest.TestCase):
def test_to_dict(self):
"""
Makes sure the config format is properly set
"""
quantization_config = HqqConfig()
hqq_orig_config = quantization_config.to_dict()
for key in hqq_orig_config:
self.assertEqual(quantization_config.quant_config[key], hqq_orig_config[key])
@slow
@require_torch_gpu
@require_accelerate
class HQQTest(unittest.TestCase):
def tearDown(self):
cleanup()
def test_fp16_quantized_model(self):
"""
Simple LLM model testing fp16
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device=torch_device
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
def test_bfp16_quantized_model_with_offloading(self):
"""
Simple LLM model testing bfloat16 with meta-data offloading
"""
q4_config = {"nbits": 4, "group_size": 64, "quant_zero": False, "quant_scale": False}
q3_config = {"nbits": 3, "group_size": 32, "quant_zero": False, "quant_scale": False, "offload_meta": True}
quant_config = HqqConfig(
dynamic_config={
"self_attn.q_proj": q4_config,
"self_attn.k_proj": q4_config,
"self_attn.v_proj": q4_config,
"self_attn.o_proj": q4_config,
"mlp.gate_proj": q3_config,
"mlp.up_proj": q3_config,
"mlp.down_proj": q3_config,
}
)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.bfloat16, device=torch_device
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
@slow
@require_torch_gpu
@require_torch_multi_gpu
@require_accelerate
class HQQTestMultiGPU(unittest.TestCase):
def tearDown(self):
cleanup()
def test_fp16_quantized_model_multigpu(self):
"""
Simple LLM model testing fp16 with multi-gpu
"""
quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0)
hqq_runner = HQQLLMRunner(
model_id=MODEL_ID, quant_config=quant_config, compute_dtype=torch.float16, device="auto"
)
check_hqqlayer(self, hqq_runner.model.model.layers[0].self_attn.v_proj)
check_forward(self, hqq_runner.model)
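To run these slow GPU tests locally, a hedged sketch (the path follows the inferred test file location above; `RUN_SLOW=1` enables tests gated by the `@slow` decorator):
```bash
RUN_SLOW=1 pytest -v tests/quantization/hqq/test_hqq.py
```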