AQLM quantizer support (#28928)
* aqlm init
* calibration and dtypes
* docs
* Readme update
* is_aqlm_available
* Simpler link in docs
* Test TODO real reference
* init _import_structure fix
* AqlmConfig autodoc
* integration aqlm
* integrations in tests
* docstring fix
* legacy typing
* Less typings
* More kernels information
* Performance -> Accuracy
* correct tests
* removed multi-gpu test
* Update docs/source/en/quantization.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Brought back multi-gpu tests

* Update src/transformers/integrations/aqlm.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/aqlm_integration/test_aqlm.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Andrei Panferov <blacksamorez@yandex-team.ru>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
This commit is contained in:
parent 63ffd56d02
commit 1ecf5f7c98
@@ -55,6 +55,9 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
# Add einops for additional model testing
RUN python3 -m pip install --no-cache-dir einops

# Add aqlm for quantization testing
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.1

# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
@@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.

</Tip>

## AqlmConfig

[[autodoc]] AqlmConfig

## AwqConfig

[[autodoc]] AwqConfig
@@ -26,6 +26,34 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan

</Tip>

## AQLM

Try AQLM on [Google Colab](https://colab.research.google.com/drive/1-xZmBRXT5Fm3Ghn4Mwa2KRypORXb855X?usp=sharing)!

Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for Large Language Models. It quantizes multiple weights together and takes advantage of the interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes.
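To make that idea concrete, here is a toy NumPy sketch of decoding one weight group; it is purely illustrative (not the `aqlm` kernels), and the shapes and code indices are made-up assumptions:

```python
import numpy as np

# Toy decode of one weight group in a "2x8" setup: 2 codebooks, 8-bit codes,
# group size 8. The group is reconstructed as the sum of one vector per codebook.
num_codebooks, codebook_size, group_size = 2, 2**8, 8
codebooks = np.random.randn(num_codebooks, codebook_size, group_size)

codes = np.array([17, 203])  # one stored code index per codebook for this group
dequantized_group = codebooks[np.arange(num_codebooks), codes].sum(axis=0)
print(dequantized_group.shape)  # (8,)
```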
Inference support for AQLM is implemented in the `aqlm` library. Make sure to install it to run the models (note that aqlm works only with python>=3.10):
```bash
pip install aqlm[gpu,cpu]
```

The library provides efficient kernels for both GPU and CPU inference.

The instructions on how to quantize models yourself, as well as all the relevant code, can be found in the corresponding GitHub [repository](https://github.com/Vahe1994/AQLM).

### AQLM configurations

AQLM quantization setups vary mainly in the number of codebooks used and the codebook size in bits. The most popular setups, and the inference kernels they support, are:

| Kernel | Number of codebooks | Codebook size, bits | Notation | Accuracy | Speedup     | Fast GPU inference | Fast CPU inference |
|--------|---------------------|---------------------|----------|----------|-------------|--------------------|--------------------|
| Triton | K                   | N                   | KxN      | -        | Up to ~0.7x | ✅                 | ❌                 |
| CUDA   | 1                   | 16                  | 1x16     | Best     | Up to ~1.3x | ✅                 | ❌                 |
| CUDA   | 2                   | 8                   | 2x8      | OK       | Up to ~3.0x | ✅                 | ❌                 |
| Numba  | K                   | 8                   | Kx8      | Good     | Up to ~4.0x | ❌                 | ✅                 |
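As a sketch of end-to-end usage, loading a checkpoint that was already quantized with AQLM looks like any other `from_pretrained` call. The checkpoint name below is the test model referenced later in this PR, and `torch_dtype="auto"` / `device_map="auto"` are illustrative choices:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

# The AqlmConfig is read from the checkpoint's config.json, so no
# quantization_config argument is needed when loading a prequantized model.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```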

## AWQ

<Tip>
@@ -1087,7 +1087,7 @@ _import_structure = {
        "is_vision_available",
        "logging",
    ],
-    "utils.quantization_config": ["AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
+    "utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
}

# sentencepiece-backed objects
@@ -5845,7 +5845,7 @@ if TYPE_CHECKING:
    )

    # bitsandbytes config
-    from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig
+    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig

    try:
        if not is_sentencepiece_available():
@@ -17,6 +17,7 @@ from ..utils import _LazyModule


_import_structure = {
+    "aqlm": ["replace_with_aqlm_linear"],
    "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
    "bitsandbytes": [
        "get_keys_to_not_convert",
@@ -80,6 +81,7 @@ _import_structure = {
}

if TYPE_CHECKING:
+    from .aqlm import replace_with_aqlm_linear
    from .awq import fuse_awq_modules, replace_with_awq_linear
    from .bitsandbytes import (
        get_keys_to_not_convert,
src/transformers/integrations/aqlm.py (new file, 99 lines)
@@ -0,0 +1,99 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AQLM (Additive Quantization of Language Model) integration file"


from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available


if is_torch_available():
    import torch.nn as nn


def replace_with_aqlm_linear(
    model,
    quantization_config=None,
    linear_weights_not_to_quantize=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`AqlmConfig`):
            The quantization config object that contains the quantization parameters.
        linear_weights_not_to_quantize (`list[str]`, *optional*):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`),
            the corresponding module will not be converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if not is_aqlm_available():
        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")

    if not is_accelerate_available():
        raise ValueError("AQLM requires Accelerate to be installed: `pip install accelerate`")

    if linear_weights_not_to_quantize is None:
        linear_weights_not_to_quantize = []

    from accelerate import init_empty_weights
    from aqlm import QuantizedLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear):
            # Check if the current key is not in the `linear_weights_not_to_quantize`
            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    model._modules[name] = QuantizedLinear(
                        in_features,
                        out_features,
                        bias=module.bias is not None,
                        in_group_size=quantization_config.in_group_size,
                        out_group_size=quantization_config.out_group_size,
                        num_codebooks=quantization_config.num_codebooks,
                        nbits_per_codebook=quantization_config.nbits_per_codebook,
                    )
                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_aqlm_linear(
                module,
                quantization_config=quantization_config,
                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
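For orientation, a hedged usage sketch of this helper, mirroring the conversion test added later in this PR (it assumes `aqlm` and `accelerate` are installed):

```python
from accelerate import init_empty_weights

from transformers import AqlmConfig, AutoConfig, OPTForCausalLM
from transformers.integrations import replace_with_aqlm_linear

# Build an empty-weight model and swap its nn.Linear layers (except lm_head)
# for aqlm.QuantizedLinear modules, as the AQLM quantizer does before loading weights.
config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

model, has_been_replaced = replace_with_aqlm_linear(
    model,
    quantization_config=AqlmConfig(),
    linear_weights_not_to_quantize=["lm_head.weight"],
)
print(has_been_replaced)  # True
```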
@@ -16,12 +16,14 @@ from typing import Dict, Optional, Union

from ..models.auto.configuration_auto import AutoConfig
from ..utils.quantization_config import (
+    AqlmConfig,
    AwqConfig,
    BitsAndBytesConfig,
    GPTQConfig,
    QuantizationConfigMixin,
    QuantizationMethod,
)
+from .quantizer_aqlm import AqlmHfQuantizer
from .quantizer_awq import AwqQuantizer
from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
@@ -33,6 +35,7 @@ AUTO_QUANTIZER_MAPPING = {
    "bitsandbytes_4bit": Bnb4BitHfQuantizer,
    "bitsandbytes_8bit": Bnb8BitHfQuantizer,
    "gptq": GptqHfQuantizer,
+    "aqlm": AqlmHfQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -40,6 +43,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
    "bitsandbytes_4bit": BitsAndBytesConfig,
    "bitsandbytes_8bit": BitsAndBytesConfig,
    "gptq": GPTQConfig,
+    "aqlm": AqlmConfig,
}
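A small sanity-check sketch of what these mappings add: the `"aqlm"` quant_method string now resolves to its config and quantizer classes. The direct instantiation at the end is illustrative only; in practice the auto machinery constructs the quantizer.

```python
from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING, AUTO_QUANTIZER_MAPPING

config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING["aqlm"]  # AqlmConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING["aqlm"]  # AqlmHfQuantizer

quantizer = quantizer_cls(config_cls())
print(type(quantizer).__name__)  # "AqlmHfQuantizer"
```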
src/transformers/quantizers/quantizer_aqlm.py (new file, 89 lines)
@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..integrations import replace_with_aqlm_linear
from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class AqlmHfQuantizer(HfQuantizer):
    """
    Quantizer of the AQLM method. Enables the loading of prequantized models.
    """

    requires_calibration = True
    required_packages = ["aqlm"]
    optimum_quantizer = None

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_accelerate_available():
            raise ImportError("Using `aqlm` quantization requires Accelerate: `pip install accelerate`")

        if not is_aqlm_available():
            raise ImportError("Using `aqlm` quantization requires AQLM: `pip install aqlm[gpu,cpu]`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            if torch.cuda.is_available():
                torch_dtype = torch.float16
                logger.info(
                    "CUDA available. Assuming AQLM inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
                )
            else:
                torch_dtype = torch.float32
                logger.info(
                    "CUDA is unavailable. Assuming AQLM inference on CPU and loading the model in `torch.float32`. To overwrite it, set `torch_dtype` manually."
                )
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        replace_with_aqlm_linear(
            model,
            quantization_config=self.quantization_config,
            linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        model._is_quantized_training_enabled = False
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    @property
    def is_serializable(self):
        return True
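Since `requires_calibration = True`, this quantizer only loads checkpoints that were already quantized with AQLM. A hedged sketch of the expected failure mode when a plain checkpoint is combined with an `AqlmConfig`, mirroring `test_raise_if_non_quantized` further down:

```python
from transformers import AqlmConfig, AutoModelForCausalLM

# Passing an AQLM config together with a non-quantized checkpoint should fail,
# because on-the-fly AQLM quantization (calibration) is not supported here.
try:
    AutoModelForCausalLM.from_pretrained("facebook/opt-125m", quantization_config=AqlmConfig())
except ValueError as err:
    print(f"Expected failure: {err}")
```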
@@ -53,6 +53,7 @@ from .integrations.deepspeed import is_deepspeed_available
from .utils import (
    is_accelerate_available,
    is_apex_available,
+    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
@@ -956,6 +957,13 @@ def require_apex(test_case):
    return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)


+def require_aqlm(test_case):
+    """
+    Decorator marking a test that requires aqlm
+    """
+    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
def require_bitsandbytes(test_case):
    """
    Decorator for bits and bytes (bnb) dependency
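A quick sketch of how the new decorator is meant to be used in a test module (the test class and method names here are made up):

```python
import unittest

from transformers.testing_utils import require_aqlm


@require_aqlm
class DummyAqlmDependentTest(unittest.TestCase):
    # Skipped automatically when the `aqlm` package is not installed.
    def test_placeholder(self):
        self.assertTrue(True)
```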
@@ -105,6 +105,7 @@ from .import_utils import (
    get_torch_version,
    is_accelerate_available,
    is_apex_available,
+    is_aqlm_available,
    is_auto_awq_available,
    is_auto_gptq_available,
    is_bitsandbytes_available,
@@ -74,6 +74,7 @@ FSDP_MIN_VERSION = "1.12.0"

_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_apex_available = _is_package_available("apex")
+_aqlm_available = _is_package_available("aqlm")
_bitsandbytes_available = _is_package_available("bitsandbytes")
# `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
_bs4_available = importlib.util.find_spec("bs4") is not None
@@ -570,6 +571,10 @@ def is_apex_available():
    return _apex_available


+def is_aqlm_available():
+    return _aqlm_available
+
+
def is_ninja_available():
    r"""
    Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the
@@ -38,6 +38,7 @@ class QuantizationMethod(str, Enum):
    BITS_AND_BYTES = "bitsandbytes"
    GPTQ = "gptq"
    AWQ = "awq"
+    AQLM = "aqlm"


class AWQLinearVersion(str, Enum):
@@ -731,3 +732,63 @@ class AwqConfig(QuantizationConfigMixin):
        loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
        loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
        return loading_attibutes_dict


@dataclass
class AqlmConfig(QuantizationConfigMixin):
    """
    This is a wrapper class around the `aqlm` parameters.

    Args:
        in_group_size (`int`, *optional*, defaults to 8):
            The group size along the input dimension.
        out_group_size (`int`, *optional*, defaults to 1):
            The group size along the output dimension. It's recommended to always use 1.
        num_codebooks (`int`, *optional*, defaults to 1):
            Number of codebooks for the Additive Quantization procedure.
        nbits_per_codebook (`int`, *optional*, defaults to 16):
            Number of bits encoding a single codebook vector. Codebook size is 2**nbits_per_codebook.
        linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*):
            List of full paths of `nn.Linear` weight parameters that shall not be quantized.
        kwargs (`Dict[str, Any]`, *optional*):
            Additional parameters from which to initialize the configuration object.
    """

    def __init__(
        self,
        in_group_size: int = 8,
        out_group_size: int = 1,
        num_codebooks: int = 1,
        nbits_per_codebook: int = 16,
        linear_weights_not_to_quantize: Optional[List[str]] = None,
        **kwargs,
    ):
        self.quant_method = QuantizationMethod.AQLM
        self.in_group_size = in_group_size
        self.out_group_size = out_group_size
        self.num_codebooks = num_codebooks
        self.nbits_per_codebook = nbits_per_codebook
        self.linear_weights_not_to_quantize = linear_weights_not_to_quantize

        self.post_init()

    def post_init(self):
        r"""
        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
        """
        if not isinstance(self.in_group_size, int):
            raise ValueError("in_group_size must be an int")
        if not isinstance(self.out_group_size, int):
            raise ValueError("out_group_size must be an int")
        if not isinstance(self.num_codebooks, int):
            raise ValueError("num_codebooks must be an int")
        if not isinstance(self.nbits_per_codebook, int):
            raise ValueError("nbits_per_codebook must be an int")

        if self.linear_weights_not_to_quantize is not None and not isinstance(
            self.linear_weights_not_to_quantize, list
        ):
            raise ValueError("linear_weights_not_to_quantize must be a list of strings")

        if self.linear_weights_not_to_quantize is None:
            self.linear_weights_not_to_quantize = []
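For reference, a small sketch of constructing the config for the 1x16 setup from the docs table and round-tripping it through a dict, mirroring the config tests below:

```python
from transformers import AqlmConfig

# Defaults correspond to the "1x16" kernel setup: 1 codebook, 16-bit codes,
# input group size 8, output group size 1.
config = AqlmConfig(linear_weights_not_to_quantize=["lm_head.weight"])

as_dict = config.to_dict()
restored = AqlmConfig.from_dict(as_dict)
print(restored.num_codebooks, restored.nbits_per_codebook)  # 1 16
```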
tests/quantization/aqlm_integration/__init__.py (new file, empty)

tests/quantization/aqlm_integration/test_aqlm.py (new file, 183 lines)
@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
from transformers.testing_utils import (
    require_accelerate,
    require_aqlm,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights


@require_torch_gpu
class AqlmConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = AqlmConfig()
        config_to_dict = quantization_config.to_dict()

        for key in config_to_dict:
            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        dict = {
            "in_group_size": 32,
            "num_codebooks": 8,
            "nbits_per_codebook": 8,
            "linear_weights_not_to_quantize": ["lm_head.weight"],
        }
        quantization_config = AqlmConfig.from_dict(dict)

        self.assertEqual(dict["in_group_size"], quantization_config.in_group_size)
        self.assertEqual(dict["num_codebooks"], quantization_config.num_codebooks)
        self.assertEqual(dict["nbits_per_codebook"], quantization_config.nbits_per_codebook)
        self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize)


@slow
@require_torch_gpu
@require_aqlm
@require_accelerate
class AqlmTest(unittest.TestCase):
    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

    input_text = "Hello my name is"

    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"

    device_map = "cuda"

    # called only once for all tests in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name,
            device_map=cls.device_map,
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """
        from aqlm import QuantizedLinear

        from transformers.integrations import replace_with_aqlm_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        quantization_config = AqlmConfig()

        with init_empty_weights():
            model = OPTForCausalLM(config)

        nb_linears = 0
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                nb_linears += 1

        model, _ = replace_with_aqlm_linear(model, quantization_config=quantization_config)
        nb_aqlm_linear = 0
        for module in model.modules():
            if isinstance(module, QuantizedLinear):
                nb_aqlm_linear += 1

        self.assertEqual(nb_linears, nb_aqlm_linear)

        # Try with `linear_weights_not_to_quantize`
        with init_empty_weights():
            model = OPTForCausalLM(config)

        model, _ = replace_with_aqlm_linear(
            model, quantization_config=quantization_config, linear_weights_not_to_quantize=["lm_head.weight"]
        )
        nb_aqlm_linear = 0
        for module in model.modules():
            if isinstance(module, QuantizedLinear):
                nb_aqlm_linear += 1

        self.assertEqual(nb_linears - 1, nb_aqlm_linear)

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_raise_if_non_quantized(self):
        model_id = "facebook/opt-125m"
        quantization_config = AqlmConfig(bits=4)

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=40)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")

        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=40)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)