Add exllamav2 better (#27111)

* add exllamav2 arg

* add test

* style

* add check

* add doc

* replace by use_exllama_v2

* fix tests

* fix doc

* style

* better condition

* fix logic

* add deprecate msg

* deprecate exllama

* remove disable_exllama from the linter

* remove

* fix warning

* Revert the commits deprecating exllama

* deprecate disable_exllama for use_exllama

* fix

* fix loading attribute

* better handling of args

* remove disable_exllama from init and linter

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* better arg

* fix warning

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* switch to dict

* Apply suggestions from code review

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* style

* nits

* style

* better tests

* style

---------

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
Marc Sun 2023-11-01 18:09:21 +01:00 committed by GitHub
parent 239cd0eaa2
commit c9e72f55b2
4 changed files with 181 additions and 28 deletions


@@ -223,16 +223,25 @@ model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", de
### Exllama kernels for faster inference
For 4-bit model, you can use the exllama kernels in order to a faster inference speed. It is activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on gpus if you want to use exllama kernels. Also, you can perform CPU inference using Auto-GPTQ for Auto-GPTQ version > 0.4.2 by passing `device_map` = "cpu". For CPU inference, you have to pass `disable_exallama = True` in the `GPTQConfig.`
For 4-bit models, you can use the exllama kernels to get a faster inference speed. It is activated by default. You can change this behavior by passing `use_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you can only overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels. Also, you can perform CPU inference with Auto-GPTQ versions > 0.4.2 by passing `device_map="cpu"`. For CPU inference, you have to pass `use_exllama=False` in [`GPTQConfig`].
```py
import torch
gptq_config = GPTQConfig(bits=4, disable_exllama=False)
gptq_config = GPTQConfig(bits=4)
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config)
```
With the release of the exllamav2 kernels, you can get faster inference speed compared to the exllama kernels. You just need to pass `exllama_config={"version": 2}` in [`GPTQConfig`]:
```py
import torch
gptq_config = GPTQConfig(bits=4, exllama_config={"version":2})
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config = gptq_config)
```
Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are fine-tuning a quantized model with PEFT.
You can find the benchmark of these kernels [here](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)
#### Fine-tune a quantized model
With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.

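The passage above mentions the CPU inference path but only shows GPU examples. A minimal sketch of that path (not part of this diff), assuming an already-quantized checkpoint with the same placeholder repo name as in the docs and Auto-GPTQ > 0.4.2:

```py
# Illustrative sketch: CPU inference as described in the updated docs above.
from transformers import AutoModelForCausalLM, GPTQConfig

gptq_config = GPTQConfig(bits=4, use_exllama=False)  # exllama kernels need the whole model on GPU
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",  # placeholder name from the documentation example
    device_map="cpu",
    quantization_config=gptq_config,
)
```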

@@ -2784,7 +2784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a "
"`quantization_config` attribute and has already quantized weights. However, loading attributes"
" (e.g. disable_exllama, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
" (e.g. use_exllama, exllama_config, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
)
if (
quantization_method_from_args == QuantizationMethod.GPTQ
@@ -2811,8 +2811,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
torch_dtype = torch.float16
else:
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict_optimum())
elif quantization_method_from_config == QuantizationMethod.AWQ:
if not torch.cuda.is_available():
raise RuntimeError("GPU is required to run AWQ quantized model.")
@@ -3539,7 +3538,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if cls.main_input_name != "input_ids":
raise RuntimeError("We can only quantize pure text model.")
quantizer.quantize_model(model, quantization_config.tokenizer)
config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict())
config.quantization_config = GPTQConfig.from_dict_optimum(quantizer.to_dict())
model._is_quantized_training_enabled = True
if quantization_method_from_config == QuantizationMethod.GPTQ:
model = quantizer.post_init_model(model)

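For reference, a hedged sketch (not part of this diff) of the load-time override that the warning above describes, reusing the quantized checkpoint from the test file further down; only the kernel loading attributes (`use_exllama`, `exllama_config`, `use_cuda_fp16`, `max_input_length`) are taken from the passed config, the rest is ignored:

```py
# Illustrative sketch: overriding kernel-related loading attributes on an already-quantized model.
import torch
from transformers import AutoModelForCausalLM, GPTQConfig

override = GPTQConfig(bits=4, exllama_config={"version": 2}, max_input_length=4028)
model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/Llama-2-7B-GPTQ",   # checkpoint also used in the tests of this PR
    revision="gptq-4bit-128g-actorder_True",
    torch_dtype=torch.float16,
    device_map={"": 0},
    quantization_config=override,            # triggers the warning shown in the hunk above
)
```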

@@ -310,6 +310,11 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
return serializable_config_dict
class ExllamaVersion(int, Enum):
ONE = 1
TWO = 2
@dataclass
class GPTQConfig(QuantizationConfigMixin):
"""
@@ -355,11 +360,14 @@
The batch size used when processing the dataset
pad_token_id (`int`, *optional*):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, *optional*, defaults to `False`):
Whether to use exllama backend. Only works with `bits` = 4.
use_exllama (`bool`, *optional*):
Whether to use exllama backend. Defaults to `True` if unset. Only works with `bits` = 4.
max_input_length (`int`, *optional*):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
length. It is specific to the exllama backend with act-order.
exllama_config (`Dict[str, Any]`, *optional*):
The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults
to `{"version": 1}` if unset.
cache_block_outputs (`bool`, *optional*, defaults to `True`):
Whether to cache block outputs to reuse as inputs for the succeeding block.
"""
@@ -380,8 +388,9 @@
module_name_preceding_first_block: Optional[List[str]] = None,
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
use_exllama: Optional[bool] = None,
max_input_length: Optional[int] = None,
exllama_config: Optional[Dict[str, Any]] = None,
cache_block_outputs: bool = True,
**kwargs,
):
@@ -400,14 +409,16 @@
self.module_name_preceding_first_block = module_name_preceding_first_block
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.use_exllama = use_exllama
self.max_input_length = max_input_length
self.exllama_config = exllama_config
self.disable_exllama = kwargs.pop("disable_exllama", None)
self.cache_block_outputs = cache_block_outputs
self.post_init()
def get_loading_attributes(self):
attibutes_dict = copy.deepcopy(self.__dict__)
loading_attibutes = ["disable_exllama", "use_cuda_fp16", "max_input_length"]
loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict
@@ -434,6 +445,73 @@
['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
)
if self.disable_exllama is None and self.use_exllama is None:
# New default behaviour
self.use_exllama = True
elif self.disable_exllama is not None and self.use_exllama is None:
# Follow pattern of old config
logger.warning(
"Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
"The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
)
self.use_exllama = not self.disable_exllama
elif self.disable_exllama is not None and self.use_exllama is not None:
# Only happens if user explicitly passes in both arguments
raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")
if self.exllama_config is None:
self.exllama_config = {"version": ExllamaVersion.ONE}
else:
if "version" not in self.exllama_config:
raise ValueError("`exllama_config` needs to have a `version` key.")
elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
exllama_version = self.exllama_config["version"]
raise ValueError(
f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
)
if self.bits == 4 and self.use_exllama:
if self.exllama_config["version"] == ExllamaVersion.ONE:
logger.info(
"You have activated exllama backend. Note that you can get better inference "
"speed using exllamav2 kernel by setting `exllama_config`."
)
elif self.exllama_config["version"] == ExllamaVersion.TWO:
optimum_version = version.parse(importlib.metadata.version("optimum"))
autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
raise ValueError(
f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
)
def to_dict(self):
config_dict = super().to_dict()
config_dict.pop("disable_exllama", None)
return config_dict
def to_dict_optimum(self):
"""
Get compatible dict for optimum gptq config
"""
quant_dict = self.to_dict()
# make it compatible with optimum config
quant_dict["disable_exllama"] = not self.use_exllama
return quant_dict
@classmethod
def from_dict_optimum(cls, config_dict):
"""
Get compatible class with optimum gptq config dict
"""
if "disable_exllama" in config_dict:
config_dict["use_exllama"] = not config_dict["disable_exllama"]
# switch to None to not trigger the warning
config_dict["disable_exllama"] = None
config = cls(**config_dict)
return config
@dataclass
class AwqConfig(QuantizationConfigMixin):

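To make the new argument handling concrete, a small illustrative sketch (not part of this diff) of the deprecation mapping and the Optimum round-trip; note that selecting `{"version": 2}` additionally requires optimum > 1.13.2 and auto-gptq > 0.4.2, as enforced in `post_init`:

```py
# Illustrative sketch of the GPTQConfig behaviour introduced in this file.
from transformers import GPTQConfig

# New default: exllama kernels enabled, version 1.
cfg = GPTQConfig(bits=4)
assert cfg.use_exllama is True
assert cfg.exllama_config == {"version": 1}

# Deprecated path: disable_exllama is still accepted but mapped to use_exllama with a warning.
legacy = GPTQConfig(bits=4, disable_exllama=True)
assert legacy.use_exllama is False

# Round-trip with optimum's GPTQQuantizer dict format, which still expects disable_exllama.
optimum_dict = cfg.to_dict_optimum()
assert optimum_dict["disable_exllama"] is False
restored = GPTQConfig.from_dict_optimum(optimum_dict)
assert restored.use_exllama is True
```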

@@ -69,9 +69,9 @@ class GPTQConfigTest(unittest.TestCase):
from optimum.gptq import GPTQQuantizer
config = GPTQConfig(bits=2)
optimum_config = GPTQQuantizer.from_dict(config.to_dict())
optimum_config = GPTQQuantizer.from_dict(config.to_dict_optimum())
self.assertEqual(optimum_config.bits, config.bits)
new_config = GPTQConfig.from_dict(optimum_config.to_dict())
new_config = GPTQConfig.from_dict_optimum(optimum_config.to_dict())
self.assertEqual(optimum_config.bits, new_config.bits)
@@ -98,7 +98,7 @@ class GPTQTest(unittest.TestCase):
bits = 4
group_size = 128
desc_act = False
disable_exllama = True
use_exllama = False
dataset = [
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
@@ -125,7 +125,7 @@
tokenizer=cls.tokenizer,
group_size=cls.group_size,
desc_act=cls.desc_act,
disable_exllama=cls.disable_exllama,
use_exllama=cls.use_exllama,
)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
@@ -147,11 +147,12 @@
def test_device_and_dtype_assignment(self):
r"""
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
Checks also if other models are casted correctly.
"""
# This should work
_ = self.quantized_model.to(0)
if self.device_map is None:
_ = self.quantized_model.to(0)
with self.assertRaises(ValueError):
# Tries with a `dtype``
@@ -177,7 +178,8 @@
desc_act=self.desc_act,
group_size=self.group_size,
bits=self.bits,
disable_exllama=self.disable_exllama,
disable_exllama=not self.use_exllama,
disable_exllamav2=True,
)
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)
@@ -196,6 +198,9 @@
# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def check_quantized_layers_type(self, model, value):
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.QUANT_TYPE == value)
def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
@@ -211,11 +216,13 @@
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
if self.disable_exllama:
if not self.use_exllama:
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname).to(0)
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
else:
# we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": 0})
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")
self.check_inference_correctness(quantized_model_from_saved)
@require_accelerate
@@ -234,14 +241,15 @@
"""
with tempfile.TemporaryDirectory() as tmpdirname:
self.quantized_model.save_pretrained(tmpdirname)
if self.disable_exllama:
self.assertEqual(self.quantized_model.config.quantization_config.disable_exllama, True)
if not self.use_exllama:
self.assertEqual(self.quantized_model.config.quantization_config.use_exllama, False)
# we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
tmpdirname, quantization_config=GPTQConfig(disable_exllama=False, bits=4), device_map={"": 0}
tmpdirname, quantization_config=GPTQConfig(use_exllama=True, bits=4), device_map={"": 0}
)
self.assertEqual(quantized_model_from_saved.config.quantization_config.disable_exllama, False)
self.assertEqual(quantized_model_from_saved.config.quantization_config.use_exllama, True)
self.assertEqual(quantized_model_from_saved.config.quantization_config.bits, self.bits)
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")
self.check_inference_correctness(quantized_model_from_saved)
@@ -255,7 +263,7 @@ class GPTQTestDeviceMap(GPTQTest):
@require_torch_multi_gpu
class GPTQTestDeviceMapExllama(GPTQTest):
device_map = "auto"
disable_exllama = False
use_exllama = True
@slow
@@ -281,8 +289,7 @@
"""
Setup quantized model
"""
cls.quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028)
cls.quantization_config = GPTQConfig(bits=4, max_input_length=4028)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
revision=cls.revision,
@@ -308,14 +315,15 @@
# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_quantized_layers_type(self):
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllama")
def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
"""
self.check_inference_correctness(self.quantized_model)
# this test will fail until the next release of optimum
@pytest.mark.skip
def test_max_input_length(self):
"""
Test if the max_input_length works. It modifies the maximum input length that of the model that runs with exllama backend.
@@ -334,6 +342,65 @@
self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
@slow
@require_optimum
@require_auto_gptq
@require_torch_gpu
@require_accelerate
class GPTQTestExllamaV2(unittest.TestCase):
"""
Test GPTQ model with exllamav2 kernel and desc_act=True (also known as act-order).
More information on those arguments here:
https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
"""
EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is Katie and I am a 20 year")
model_name = "hf-internal-testing/Llama-2-7B-GPTQ"
revision = "gptq-4bit-128g-actorder_True"
input_text = "Hello my name is"
@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""
cls.quantization_config = GPTQConfig(bits=4, exllama_config={"version": 2})
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
revision=cls.revision,
torch_dtype=torch.float16,
device_map={"": 0},
quantization_config=cls.quantization_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)
def test_quantized_layers_type(self):
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllamav2")
def check_inference_correctness(self, model):
"""
Test the generation quality of the quantized model and see that we are matching the expected output.
Given that we are operating on small numbers + the testing model is relatively small, we might not get
the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
"""
# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
# Check the exactness of the results
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
"""
self.check_inference_correctness(self.quantized_model)
# fail when run all together
@pytest.mark.skip
@require_accelerate
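As a closing illustration of what `GPTQTestExllamaV2` exercises, a hedged sketch (not part of this diff; needs a GPU plus optimum > 1.13.2 and auto-gptq > 0.4.2) that loads the same test checkpoint with the exllamav2 kernels and checks which kernel is in use:

```py
# Illustrative sketch mirroring GPTQTestExllamaV2 above.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_name = "hf-internal-testing/Llama-2-7B-GPTQ"  # checkpoint used by the test
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    revision="gptq-4bit-128g-actorder_True",
    torch_dtype=torch.float16,
    device_map={"": 0},
    quantization_config=GPTQConfig(bits=4, exllama_config={"version": 2}),
)
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

# "exllamav2" with the config above; "exllama" for version 1, "cuda-old" when use_exllama=False.
print(model.model.layers[0].self_attn.k_proj.QUANT_TYPE)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(0)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=10)[0], skip_special_tokens=True))
```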