AQLM quantizer support (#28928)

* aqlm init

* calibration and dtypes

* docs

* Readme update

* is_aqlm_available

* Simpler link in docs

* Test TODO real reference

* init _import_structure fix

* AqlmConfig autodoc

* integration aqlm

* integrations in tests

* docstring fix

* legacy typing

* Less typings

* More kernels information

* Performance -> Accuracy

* correct tests

* removed multi-gpu test

* Update docs/source/en/quantization.md

Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

* Brought back multi-gpu tests

* Update src/transformers/integrations/aqlm.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update tests/quantization/aqlm_integration/test_aqlm.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

---------

Co-authored-by: Andrei Panferov <blacksamorez@yandex-team.ru>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Andrei Panferov 2024-02-14 11:25:41 +03:00 committed by GitHub
parent 63ffd56d02
commit 1ecf5f7c98
14 changed files with 489 additions and 2 deletions

View File

@@ -55,6 +55,9 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
 # Add einops for additional model testing
 RUN python3 -m pip install --no-cache-dir einops
 
+# Add aqlm for quantization testing
+RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.1
+
 # Add autoawq for quantization testing
 RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl

View File

@@ -26,6 +26,10 @@ Learn how to quantize models in the [Quantization](../quantization) guide.
 </Tip>
 
+## AqlmConfig
+
+[[autodoc]] AqlmConfig
+
 ## AwqConfig
 
 [[autodoc]] AwqConfig

View File

@@ -26,6 +26,34 @@ Interested in adding a new quantization method to Transformers? Read the [HfQuan
 </Tip>
 
+## AQLM
+
+Try AQLM on [Google Colab](https://colab.research.google.com/drive/1-xZmBRXT5Fm3Ghn4Mwa2KRypORXb855X?usp=sharing)!
+
+Additive Quantization of Language Models ([AQLM](https://arxiv.org/abs/2401.06118)) is a compression method for large language models. It quantizes multiple weights together and takes advantage of interdependencies between them. AQLM represents groups of 8-16 weights as a sum of multiple vector codes.
+
+Inference support for AQLM is implemented in the `aqlm` library. Make sure to install it to run the models (note that `aqlm` only works with Python 3.10 or newer):
+
+```bash
+pip install aqlm[gpu,cpu]
+```
+
+The library provides efficient kernels for both GPU and CPU inference.
+
+The instructions on how to quantize models yourself, as well as all the relevant code, can be found in the corresponding GitHub [repository](https://github.com/Vahe1994/AQLM).
+
+### AQLM configurations
+
+AQLM quantization setups vary mainly in the number of codebooks used as well as the codebook size in bits. The most popular setups, as well as the inference kernels they support, are:
+
+| Kernel | Number of codebooks | Codebook size, bits | Notation | Accuracy | Speedup | Fast GPU inference | Fast CPU inference |
+|---|---------------------|---------------------|----------|-------------|-------------|--------------------|--------------------|
+| Triton | K | N | KxN | - | Up to ~0.7x | ✅ | ❌ |
+| CUDA | 1 | 16 | 1x16 | Best | Up to ~1.3x | ✅ | ❌ |
+| CUDA | 2 | 8 | 2x8 | OK | Up to ~3.0x | ✅ | ❌ |
+| Numba | K | 8 | Kx8 | Good | Up to ~4.0x | ❌ | ✅ |
+
 ## AWQ
 
 <Tip>
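An illustrative sketch of what the documentation above describes, not taken verbatim from the committed docs: loading a prequantized AQLM checkpoint end to end, assuming a CUDA GPU and using the prequantized test checkpoint referenced later in this commit's tests.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Prequantized 2-bit (1x16 codebook) checkpoint from the tests; any AQLM checkpoint on the Hub loads the same way.
model_id = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

tokenizer = AutoTokenizer.from_pretrained(model_id)
# The AQLM quantizer is picked up automatically from the checkpoint's quantization_config;
# torch_dtype defaults to float16 on CUDA (see AqlmHfQuantizer.update_torch_dtype below).
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=40)[0], skip_special_tokens=True))
```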

View File

@@ -1087,7 +1087,7 @@ _import_structure = {
         "is_vision_available",
         "logging",
     ],
-    "utils.quantization_config": ["AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
+    "utils.quantization_config": ["AqlmConfig", "AwqConfig", "BitsAndBytesConfig", "GPTQConfig"],
 }
 
 # sentencepiece-backed objects
@@ -5845,7 +5845,7 @@ if TYPE_CHECKING:
     )
 
     # bitsandbytes config
-    from .utils.quantization_config import AwqConfig, BitsAndBytesConfig, GPTQConfig
+    from .utils.quantization_config import AqlmConfig, AwqConfig, BitsAndBytesConfig, GPTQConfig
 
     try:
         if not is_sentencepiece_available():

View File

@@ -17,6 +17,7 @@ from ..utils import _LazyModule
 
 _import_structure = {
+    "aqlm": ["replace_with_aqlm_linear"],
     "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
     "bitsandbytes": [
         "get_keys_to_not_convert",
@@ -80,6 +81,7 @@ _import_structure = {
 }
 
 if TYPE_CHECKING:
+    from .aqlm import replace_with_aqlm_linear
     from .awq import fuse_awq_modules, replace_with_awq_linear
     from .bitsandbytes import (
        get_keys_to_not_convert,

View File

@@ -0,0 +1,99 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"AQLM (Additive Quantization of Language Model) integration file"

from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available


if is_torch_available():
    import torch.nn as nn


def replace_with_aqlm_linear(
    model,
    quantization_config=None,
    linear_weights_not_to_quantize=None,
    current_key_name=None,
    has_been_replaced=False,
):
    """
    Public method that recursively replaces the Linear layers of the given model with AQLM quantized layers.
    `accelerate` is needed to use this method. Returns the converted model and a boolean that indicates if the
    conversion has been successful or not.

    Args:
        model (`torch.nn.Module`):
            The model to convert, can be any `torch.nn.Module` instance.
        quantization_config (`AqlmConfig`):
            The quantization config object that contains the quantization parameters.
        linear_weights_not_to_quantize (`list[str]`, *optional*):
            A list of nn.Linear weights to not convert. If a parameter path is in the list (e.g. `lm_head.weight`),
            the corresponding module will not be converted.
        current_key_name (`list`, *optional*):
            A list that contains the current key name. This is used for recursion and should not be passed by the user.
        has_been_replaced (`bool`, *optional*):
            A boolean that indicates if the conversion has been successful or not. This is used for recursion and
            should not be passed by the user.
    """
    if not is_aqlm_available():
        raise ValueError("AQLM is not available. Please install it with `pip install aqlm[cpu,gpu]`")

    if not is_accelerate_available():
        raise ValueError("AQLM requires Accelerate to be installed: `pip install accelerate`")

    if linear_weights_not_to_quantize is None:
        linear_weights_not_to_quantize = []

    from accelerate import init_empty_weights
    from aqlm import QuantizedLinear

    for name, module in model.named_children():
        if current_key_name is None:
            current_key_name = []
        current_key_name.append(name)

        if isinstance(module, nn.Linear):
            # Check if the current key is not in the `linear_weights_not_to_quantize`
            if ".".join(current_key_name) + ".weight" not in linear_weights_not_to_quantize:
                with init_empty_weights():
                    in_features = module.in_features
                    out_features = module.out_features

                    model._modules[name] = QuantizedLinear(
                        in_features,
                        out_features,
                        bias=module.bias is not None,
                        in_group_size=quantization_config.in_group_size,
                        out_group_size=quantization_config.out_group_size,
                        num_codebooks=quantization_config.num_codebooks,
                        nbits_per_codebook=quantization_config.nbits_per_codebook,
                    )
                    has_been_replaced = True

                    # Store the module class in case we need to transpose the weight later
                    model._modules[name].source_cls = type(module)
                    # Force requires grad to False to avoid unexpected errors
                    model._modules[name].requires_grad_(False)
        if len(list(module.children())) > 0:
            _, has_been_replaced = replace_with_aqlm_linear(
                module,
                quantization_config=quantization_config,
                linear_weights_not_to_quantize=linear_weights_not_to_quantize,
                current_key_name=current_key_name,
                has_been_replaced=has_been_replaced,
            )
        # Remove the last key for recursion
        current_key_name.pop(-1)
    return model, has_been_replaced
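A small illustrative sketch of how `replace_with_aqlm_linear` can be exercised on an empty-weight model; it mirrors the conversion test added later in this commit, and the OPT checkpoint name comes from that test.

```python
from accelerate import init_empty_weights
from transformers import AqlmConfig, AutoConfig, OPTForCausalLM
from transformers.integrations import replace_with_aqlm_linear

config = AutoConfig.from_pretrained("facebook/opt-350m")
with init_empty_weights():
    model = OPTForCausalLM(config)

# Swap every nn.Linear except the LM head for an aqlm QuantizedLinear shell (weights stay on the meta device).
model, has_been_replaced = replace_with_aqlm_linear(
    model,
    quantization_config=AqlmConfig(),  # default 1x16 setup
    linear_weights_not_to_quantize=["lm_head.weight"],
)
assert has_been_replaced
```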

View File

@@ -16,12 +16,14 @@ from typing import Dict, Optional, Union
 from ..models.auto.configuration_auto import AutoConfig
 from ..utils.quantization_config import (
+    AqlmConfig,
     AwqConfig,
     BitsAndBytesConfig,
     GPTQConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
 )
+from .quantizer_aqlm import AqlmHfQuantizer
 from .quantizer_awq import AwqQuantizer
 from .quantizer_bnb_4bit import Bnb4BitHfQuantizer
 from .quantizer_bnb_8bit import Bnb8BitHfQuantizer
@@ -33,6 +35,7 @@ AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": Bnb4BitHfQuantizer,
     "bitsandbytes_8bit": Bnb8BitHfQuantizer,
     "gptq": GptqHfQuantizer,
+    "aqlm": AqlmHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
@@ -40,6 +43,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gptq": GPTQConfig,
+    "aqlm": AqlmConfig,
 }
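A rough sketch of what the new mapping entries buy us, assuming the `AutoHfQuantizer.from_config` helper defined elsewhere in this module (it is not shown in the hunk above): an `"aqlm"` quantization config now resolves to `AqlmHfQuantizer`.

```python
from transformers.quantizers.auto import AutoHfQuantizer  # helper assumed from the surrounding module
from transformers.utils.quantization_config import AqlmConfig

# "aqlm" now dispatches through AUTO_QUANTIZATION_CONFIG_MAPPING / AUTO_QUANTIZER_MAPPING.
quantizer = AutoHfQuantizer.from_config(AqlmConfig())
print(type(quantizer).__name__)  # expected: AqlmHfQuantizer
```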

View File

@@ -0,0 +1,89 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

from ..integrations import replace_with_aqlm_linear
from ..utils import is_accelerate_available, is_aqlm_available, is_torch_available, logging
from ..utils.quantization_config import QuantizationConfigMixin


if is_torch_available():
    import torch

logger = logging.get_logger(__name__)


class AqlmHfQuantizer(HfQuantizer):
    """
    Quantizer of the AQLM method. Enables the loading of prequantized models.
    """

    requires_calibration = True
    required_packages = ["aqlm"]
    optimum_quantizer = None

    def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
        super().__init__(quantization_config, **kwargs)
        self.quantization_config = quantization_config

    def validate_environment(self, *args, **kwargs):
        if not is_accelerate_available():
            raise ImportError("Using `aqlm` quantization requires Accelerate: `pip install accelerate`")

        if not is_aqlm_available():
            raise ImportError("Using `aqlm` quantization requires AQLM: `pip install aqlm[gpu,cpu]`")

    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
        if torch_dtype is None:
            if torch.cuda.is_available():
                torch_dtype = torch.float16
                logger.info(
                    "CUDA available. Assuming AQLM inference on GPU and loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually."
                )
            else:
                torch_dtype = torch.float32
                logger.info(
                    "CUDA is unavailable. Assuming AQLM inference on CPU and loading the model in `torch.float32`. To overwrite it, set `torch_dtype` manually."
                )
        return torch_dtype

    def _process_model_before_weight_loading(
        self,
        model: "PreTrainedModel",
        **kwargs,
    ):
        replace_with_aqlm_linear(
            model,
            quantization_config=self.quantization_config,
            linear_weights_not_to_quantize=self.quantization_config.linear_weights_not_to_quantize,
        )
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
        model._is_quantized_training_enabled = False
        return model

    @property
    def is_trainable(self, model: Optional["PreTrainedModel"] = None):
        return False

    @property
    def is_serializable(self):
        return True
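To illustrate the `update_torch_dtype` behavior above, a sketch assuming a CUDA GPU and the prequantized test checkpoint used later in this commit: leaving `torch_dtype` unset lets the quantizer choose the dtype, while passing it explicitly bypasses that default.

```python
import torch
from transformers import AutoModelForCausalLM

model_id = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"  # checkpoint from the tests below

# torch_dtype left unset: update_torch_dtype picks float16 on CUDA (float32 on CPU) and logs the choice.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda")

# Setting torch_dtype explicitly skips that default entirely.
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", torch_dtype=torch.float16)
```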

View File

@@ -53,6 +53,7 @@ from .integrations.deepspeed import is_deepspeed_available
 from .utils import (
     is_accelerate_available,
     is_apex_available,
+    is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,
@@ -956,6 +957,13 @@ def require_apex(test_case):
     return unittest.skipUnless(is_apex_available(), "test requires apex")(test_case)
 
 
+def require_aqlm(test_case):
+    """
+    Decorator marking a test that requires aqlm
+    """
+    return unittest.skipUnless(is_aqlm_available(), "test requires aqlm")(test_case)
+
+
 def require_bitsandbytes(test_case):
     """
     Decorator for bits and bytes (bnb) dependency
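A hypothetical usage sketch of the new `require_aqlm` decorator (the test class and method names here are made up for illustration):

```python
import unittest

from transformers.testing_utils import require_aqlm


@require_aqlm
class AqlmAvailabilityTest(unittest.TestCase):  # hypothetical test class, not part of this commit
    def test_quantized_linear_is_importable(self):
        from aqlm import QuantizedLinear  # only runs when aqlm is installed; otherwise the test is skipped

        self.assertTrue(callable(QuantizedLinear))
```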

View File

@@ -105,6 +105,7 @@ from .import_utils import (
     get_torch_version,
     is_accelerate_available,
     is_apex_available,
+    is_aqlm_available,
     is_auto_awq_available,
     is_auto_gptq_available,
     is_bitsandbytes_available,

View File

@@ -74,6 +74,7 @@ FSDP_MIN_VERSION = "1.12.0"
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
 _apex_available = _is_package_available("apex")
+_aqlm_available = _is_package_available("aqlm")
 _bitsandbytes_available = _is_package_available("bitsandbytes")
 # `importlib.metadata.version` doesn't work with `bs4` but `beautifulsoup4`. For `importlib.util.find_spec`, reversed.
 _bs4_available = importlib.util.find_spec("bs4") is not None
@@ -570,6 +571,10 @@ def is_apex_available():
     return _apex_available
 
 
+def is_aqlm_available():
+    return _aqlm_available
+
+
 def is_ninja_available():
     r"""
     Code comes from *torch.utils.cpp_extension.is_ninja_available()*. Returns `True` if the

View File

@@ -38,6 +38,7 @@ class QuantizationMethod(str, Enum):
     BITS_AND_BYTES = "bitsandbytes"
     GPTQ = "gptq"
     AWQ = "awq"
+    AQLM = "aqlm"
 
 
 class AWQLinearVersion(str, Enum):
@@ -731,3 +732,63 @@ class AwqConfig(QuantizationConfigMixin):
         loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
         loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
         return loading_attibutes_dict
+
+
+@dataclass
+class AqlmConfig(QuantizationConfigMixin):
+    """
+    This is a wrapper class for `aqlm` parameters.
+
+    Args:
+        in_group_size (`int`, *optional*, defaults to 8):
+            The group size along the input dimension.
+        out_group_size (`int`, *optional*, defaults to 1):
+            The group size along the output dimension. It's recommended to always use 1.
+        num_codebooks (`int`, *optional*, defaults to 1):
+            Number of codebooks for the Additive Quantization procedure.
+        nbits_per_codebook (`int`, *optional*, defaults to 16):
+            Number of bits encoding a single codebook vector. Codebook size is 2**nbits_per_codebook.
+        linear_weights_not_to_quantize (`Optional[List[str]]`, *optional*):
+            List of full paths of `nn.Linear` weight parameters that shall not be quantized.
+        kwargs (`Dict[str, Any]`, *optional*):
+            Additional parameters from which to initialize the configuration object.
+    """
+
+    def __init__(
+        self,
+        in_group_size: int = 8,
+        out_group_size: int = 1,
+        num_codebooks: int = 1,
+        nbits_per_codebook: int = 16,
+        linear_weights_not_to_quantize: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        self.quant_method = QuantizationMethod.AQLM
+        self.in_group_size = in_group_size
+        self.out_group_size = out_group_size
+        self.num_codebooks = num_codebooks
+        self.nbits_per_codebook = nbits_per_codebook
+        self.linear_weights_not_to_quantize = linear_weights_not_to_quantize
+
+        self.post_init()
+
+    def post_init(self):
+        r"""
+        Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
+        """
+        if not isinstance(self.in_group_size, int):
+            raise ValueError("in_group_size must be an integer")
+        if not isinstance(self.out_group_size, int):
+            raise ValueError("out_group_size must be an integer")
+        if not isinstance(self.num_codebooks, int):
+            raise ValueError("num_codebooks must be an integer")
+        if not isinstance(self.nbits_per_codebook, int):
+            raise ValueError("nbits_per_codebook must be an integer")
+
+        if self.linear_weights_not_to_quantize is not None and not isinstance(
+            self.linear_weights_not_to_quantize, list
+        ):
+            raise ValueError("linear_weights_not_to_quantize must be a list of strings")
+
+        if self.linear_weights_not_to_quantize is None:
+            self.linear_weights_not_to_quantize = []
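An illustrative sketch of constructing the new `AqlmConfig`, here for the "2x8" setup from the docs table above; the parameter values are examples only.

```python
from transformers import AqlmConfig

# Two codebooks of 8 bits each; lm_head is kept unquantized, as in the tests.
config = AqlmConfig(
    in_group_size=8,
    out_group_size=1,
    num_codebooks=2,
    nbits_per_codebook=8,
    linear_weights_not_to_quantize=["lm_head.weight"],
)
print(config.to_dict())  # round-trips via AqlmConfig.from_dict, as exercised in the tests below
```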

View File

@@ -0,0 +1,183 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import tempfile
import unittest

from transformers import AqlmConfig, AutoConfig, AutoModelForCausalLM, AutoTokenizer, OPTForCausalLM
from transformers.testing_utils import (
    require_accelerate,
    require_aqlm,
    require_torch_gpu,
    require_torch_multi_gpu,
    slow,
    torch_device,
)
from transformers.utils import is_accelerate_available, is_torch_available


if is_torch_available():
    import torch

if is_accelerate_available():
    from accelerate import init_empty_weights


@require_torch_gpu
class AqlmConfigTest(unittest.TestCase):
    def test_to_dict(self):
        """
        Simple test that checks if one uses a config and converts it to a dict, the dict is the same as the config object
        """
        quantization_config = AqlmConfig()
        config_to_dict = quantization_config.to_dict()

        for key in config_to_dict:
            self.assertEqual(getattr(quantization_config, key), config_to_dict[key])

    def test_from_dict(self):
        """
        Simple test that checks if one uses a dict and converts it to a config object, the config object is the same as the dict
        """
        dict = {
            "in_group_size": 32,
            "num_codebooks": 8,
            "nbits_per_codebook": 8,
            "linear_weights_not_to_quantize": ["lm_head.weight"],
        }
        quantization_config = AqlmConfig.from_dict(dict)

        self.assertEqual(dict["in_group_size"], quantization_config.in_group_size)
        self.assertEqual(dict["num_codebooks"], quantization_config.num_codebooks)
        self.assertEqual(dict["nbits_per_codebook"], quantization_config.nbits_per_codebook)
        self.assertEqual(dict["linear_weights_not_to_quantize"], quantization_config.linear_weights_not_to_quantize)


@slow
@require_torch_gpu
@require_aqlm
@require_accelerate
class AqlmTest(unittest.TestCase):
    model_name = "BlackSamorez/Mixtral-8x7b-AQLM-2Bit-1x16-hf-test-dispatch"

    input_text = "Hello my name is"

    EXPECTED_OUTPUT = "Hello my name is Katie and I am a 20 year old student at the University of North Carolina at Chapel Hill. I am currently a sophomore and am majoring in Psychology. I am"

    device_map = "cuda"

    # called only once for all tests in this class
    @classmethod
    def setUpClass(cls):
        """
        Setup quantized model
        """
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
            cls.model_name,
            device_map=cls.device_map,
        )

    def tearDown(self):
        gc.collect()
        torch.cuda.empty_cache()
        gc.collect()

    def test_quantized_model_conversion(self):
        """
        Simple test that checks if the quantized model has been converted properly
        """
        from aqlm import QuantizedLinear

        from transformers.integrations import replace_with_aqlm_linear

        model_id = "facebook/opt-350m"
        config = AutoConfig.from_pretrained(model_id, revision="cb32f77e905cccbca1d970436fb0f5e6b58ee3c5")
        quantization_config = AqlmConfig()

        with init_empty_weights():
            model = OPTForCausalLM(config)

        nb_linears = 0
        for module in model.modules():
            if isinstance(module, torch.nn.Linear):
                nb_linears += 1

        model, _ = replace_with_aqlm_linear(model, quantization_config=quantization_config)
        nb_aqlm_linear = 0
        for module in model.modules():
            if isinstance(module, QuantizedLinear):
                nb_aqlm_linear += 1

        self.assertEqual(nb_linears, nb_aqlm_linear)

        # Try with `linear_weights_not_to_quantize`
        with init_empty_weights():
            model = OPTForCausalLM(config)

        model, _ = replace_with_aqlm_linear(
            model, quantization_config=quantization_config, linear_weights_not_to_quantize=["lm_head.weight"]
        )
        nb_aqlm_linear = 0
        for module in model.modules():
            if isinstance(module, QuantizedLinear):
                nb_aqlm_linear += 1

        self.assertEqual(nb_linears - 1, nb_aqlm_linear)

    def test_quantized_model(self):
        """
        Simple test that checks if the quantized model is working properly
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        output = self.quantized_model.generate(**input_ids, max_new_tokens=40)
        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    def test_raise_if_non_quantized(self):
        model_id = "facebook/opt-125m"
        quantization_config = AqlmConfig(bits=4)

        with self.assertRaises(ValueError):
            _ = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)

    def test_save_pretrained(self):
        """
        Simple test that checks if the quantized model is working properly after being saved and loaded
        """
        with tempfile.TemporaryDirectory() as tmpdirname:
            self.quantized_model.save_pretrained(tmpdirname)
            model = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map=self.device_map)

            input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

            output = model.generate(**input_ids, max_new_tokens=40)
            self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)

    @require_torch_multi_gpu
    def test_quantized_model_multi_gpu(self):
        """
        Simple test that checks if the quantized model is working properly with multiple GPUs
        """
        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)

        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_name, device_map="auto")

        self.assertTrue(set(quantized_model.hf_device_map.values()) == {0, 1})

        output = quantized_model.generate(**input_ids, max_new_tokens=40)

        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)