mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Add exllamav2 better (#27111)
* add_ xllamav2 arg * add test * style * add check * add doc * replace by use_exllama_v2 * fix tests * fix doc * style * better condition * fix logic * add deprecate msg * deprecate exllama * remove disable_exllama from the linter * remove * fix warning * Revert the commits deprecating exllama * deprecate disable_exllama for use_exllama * fix * fix loading attribute * better handling of args * remove disable_exllama from init and linter * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * better arg * fix warning * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * switch to dict * Apply suggestions from code review Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> * style * nits * style * better tests * style --------- Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
parent
239cd0eaa2
commit
c9e72f55b2
@ -223,16 +223,25 @@ model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", de
|
||||
|
||||
### Exllama kernels for faster inference
|
||||
|
||||
For 4-bit model, you can use the exllama kernels in order to a faster inference speed. It is activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on gpus if you want to use exllama kernels. Also, you can perform CPU inference using Auto-GPTQ for Auto-GPTQ version > 0.4.2 by passing `device_map` = "cpu". For CPU inference, you have to pass `disable_exallama = True` in the `GPTQConfig.`
|
||||
For 4-bit model, you can use the exllama kernels in order to a faster inference speed. It is activated by default. You can change that behavior by passing `use_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on gpus if you want to use exllama kernels. Also, you can perform CPU inference using Auto-GPTQ for Auto-GPTQ version > 0.4.2 by passing `device_map` = "cpu". For CPU inference, you have to pass `use_exllama = False` in the `GPTQConfig.`
|
||||
|
||||
```py
|
||||
import torch
|
||||
gptq_config = GPTQConfig(bits=4, disable_exllama=False)
|
||||
gptq_config = GPTQConfig(bits=4)
|
||||
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config=gptq_config)
|
||||
```
|
||||
|
||||
With the release of the exllamav2 kernels, you can get faster inference speed compared to the exllama kernels. You just need to pass `exllama_config={"version": 2}` in [`GPTQConfig`]:
|
||||
|
||||
```py
|
||||
import torch
|
||||
gptq_config = GPTQConfig(bits=4, exllama_config={"version":2})
|
||||
model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config = gptq_config)
|
||||
```
|
||||
|
||||
Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are finetuning a quantized model with peft.
|
||||
|
||||
You can find the benchmark of these kernels [here](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)
|
||||
#### Fine-tune a quantized model
|
||||
|
||||
With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
|
||||
|
@ -2784,7 +2784,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
logger.warning(
|
||||
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a "
|
||||
"`quantization_config` attribute and has already quantized weights. However, loading attributes"
|
||||
" (e.g. disable_exllama, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
|
||||
" (e.g. use_exllama, exllama_config, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
|
||||
)
|
||||
if (
|
||||
quantization_method_from_args == QuantizationMethod.GPTQ
|
||||
@ -2811,8 +2811,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
|
||||
|
||||
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
|
||||
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict_optimum())
|
||||
elif quantization_method_from_config == QuantizationMethod.AWQ:
|
||||
if not torch.cuda.is_available():
|
||||
raise RuntimeError("GPU is required to run AWQ quantized model.")
|
||||
@ -3539,7 +3538,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if cls.main_input_name != "input_ids":
|
||||
raise RuntimeError("We can only quantize pure text model.")
|
||||
quantizer.quantize_model(model, quantization_config.tokenizer)
|
||||
config.quantization_config = GPTQConfig.from_dict(quantizer.to_dict())
|
||||
config.quantization_config = GPTQConfig.from_dict_optimum(quantizer.to_dict())
|
||||
model._is_quantized_training_enabled = True
|
||||
if quantization_method_from_config == QuantizationMethod.GPTQ:
|
||||
model = quantizer.post_init_model(model)
|
||||
|
@ -310,6 +310,11 @@ class BitsAndBytesConfig(QuantizationConfigMixin):
|
||||
return serializable_config_dict
|
||||
|
||||
|
||||
class ExllamaVersion(int, Enum):
|
||||
ONE = 1
|
||||
TWO = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class GPTQConfig(QuantizationConfigMixin):
|
||||
"""
|
||||
@ -355,11 +360,14 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
The batch size used when processing the dataset
|
||||
pad_token_id (`int`, *optional*):
|
||||
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
|
||||
disable_exllama (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use exllama backend. Only works with `bits` = 4.
|
||||
use_exllama (`bool`, *optional*):
|
||||
Whether to use exllama backend. Defaults to `True` if unset. Only works with `bits` = 4.
|
||||
max_input_length (`int`, *optional*):
|
||||
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
|
||||
length. It is specific to the exllama backend with act-order.
|
||||
exllama_config (`Dict[str, Any]`, *optional*):
|
||||
The exllama config. You can specify the version of the exllama kernel through the `version` key. Defaults
|
||||
to `{"version": 1}` if unset.
|
||||
cache_block_outputs (`bool`, *optional*, defaults to `True`):
|
||||
Whether to cache block outputs to reuse as inputs for the succeeding block.
|
||||
"""
|
||||
@ -380,8 +388,9 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
module_name_preceding_first_block: Optional[List[str]] = None,
|
||||
batch_size: int = 1,
|
||||
pad_token_id: Optional[int] = None,
|
||||
disable_exllama: bool = False,
|
||||
use_exllama: Optional[bool] = None,
|
||||
max_input_length: Optional[int] = None,
|
||||
exllama_config: Optional[Dict[str, Any]] = None,
|
||||
cache_block_outputs: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
@ -400,14 +409,16 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
self.module_name_preceding_first_block = module_name_preceding_first_block
|
||||
self.batch_size = batch_size
|
||||
self.pad_token_id = pad_token_id
|
||||
self.disable_exllama = disable_exllama
|
||||
self.use_exllama = use_exllama
|
||||
self.max_input_length = max_input_length
|
||||
self.exllama_config = exllama_config
|
||||
self.disable_exllama = kwargs.pop("disable_exllama", None)
|
||||
self.cache_block_outputs = cache_block_outputs
|
||||
self.post_init()
|
||||
|
||||
def get_loading_attributes(self):
|
||||
attibutes_dict = copy.deepcopy(self.__dict__)
|
||||
loading_attibutes = ["disable_exllama", "use_cuda_fp16", "max_input_length"]
|
||||
loading_attibutes = ["disable_exllama", "use_exllama", "exllama_config", "use_cuda_fp16", "max_input_length"]
|
||||
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
|
||||
return loading_attibutes_dict
|
||||
|
||||
@ -434,6 +445,73 @@ class GPTQConfig(QuantizationConfigMixin):
|
||||
['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
|
||||
)
|
||||
|
||||
if self.disable_exllama is None and self.use_exllama is None:
|
||||
# New default behaviour
|
||||
self.use_exllama = True
|
||||
elif self.disable_exllama is not None and self.use_exllama is None:
|
||||
# Follow pattern of old config
|
||||
logger.warning(
|
||||
"Using `disable_exllama` is deprecated and will be removed in version 4.37. Use `use_exllama` instead and specify the version with `exllama_config`."
|
||||
"The value of `use_exllama` will be overwritten by `disable_exllama` passed in `GPTQConfig` or stored in your config file."
|
||||
)
|
||||
self.use_exllama = not self.disable_exllama
|
||||
elif self.disable_exllama is not None and self.use_exllama is not None:
|
||||
# Only happens if user explicitly passes in both arguments
|
||||
raise ValueError("Cannot specify both `disable_exllama` and `use_exllama`. Please use just `use_exllama`")
|
||||
|
||||
if self.exllama_config is None:
|
||||
self.exllama_config = {"version": ExllamaVersion.ONE}
|
||||
else:
|
||||
if "version" not in self.exllama_config:
|
||||
raise ValueError("`exllama_config` needs to have a `version` key.")
|
||||
elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
|
||||
exllama_version = self.exllama_config["version"]
|
||||
raise ValueError(
|
||||
f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
|
||||
)
|
||||
|
||||
if self.bits == 4 and self.use_exllama:
|
||||
if self.exllama_config["version"] == ExllamaVersion.ONE:
|
||||
logger.info(
|
||||
"You have activated exllama backend. Note that you can get better inference "
|
||||
"speed using exllamav2 kernel by setting `exllama_config`."
|
||||
)
|
||||
elif self.exllama_config["version"] == ExllamaVersion.TWO:
|
||||
optimum_version = version.parse(importlib.metadata.version("optimum"))
|
||||
autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
|
||||
if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
|
||||
raise ValueError(
|
||||
f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
config_dict = super().to_dict()
|
||||
config_dict.pop("disable_exllama", None)
|
||||
return config_dict
|
||||
|
||||
def to_dict_optimum(self):
|
||||
"""
|
||||
Get compatible dict for optimum gptq config
|
||||
"""
|
||||
quant_dict = self.to_dict()
|
||||
# make it compatible with optimum config
|
||||
quant_dict["disable_exllama"] = not self.use_exllama
|
||||
return quant_dict
|
||||
|
||||
@classmethod
|
||||
def from_dict_optimum(cls, config_dict):
|
||||
"""
|
||||
Get compatible class with optimum gptq config dict
|
||||
"""
|
||||
|
||||
if "disable_exllama" in config_dict:
|
||||
config_dict["use_exllama"] = not config_dict["disable_exllama"]
|
||||
# switch to None to not trigger the warning
|
||||
config_dict["disable_exllama"] = None
|
||||
|
||||
config = cls(**config_dict)
|
||||
return config
|
||||
|
||||
|
||||
@dataclass
|
||||
class AwqConfig(QuantizationConfigMixin):
|
||||
|
@ -69,9 +69,9 @@ class GPTQConfigTest(unittest.TestCase):
|
||||
from optimum.gptq import GPTQQuantizer
|
||||
|
||||
config = GPTQConfig(bits=2)
|
||||
optimum_config = GPTQQuantizer.from_dict(config.to_dict())
|
||||
optimum_config = GPTQQuantizer.from_dict(config.to_dict_optimum())
|
||||
self.assertEqual(optimum_config.bits, config.bits)
|
||||
new_config = GPTQConfig.from_dict(optimum_config.to_dict())
|
||||
new_config = GPTQConfig.from_dict_optimum(optimum_config.to_dict())
|
||||
self.assertEqual(optimum_config.bits, new_config.bits)
|
||||
|
||||
|
||||
@ -98,7 +98,7 @@ class GPTQTest(unittest.TestCase):
|
||||
bits = 4
|
||||
group_size = 128
|
||||
desc_act = False
|
||||
disable_exllama = True
|
||||
use_exllama = False
|
||||
|
||||
dataset = [
|
||||
"auto-gptq is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."
|
||||
@ -125,7 +125,7 @@ class GPTQTest(unittest.TestCase):
|
||||
tokenizer=cls.tokenizer,
|
||||
group_size=cls.group_size,
|
||||
desc_act=cls.desc_act,
|
||||
disable_exllama=cls.disable_exllama,
|
||||
use_exllama=cls.use_exllama,
|
||||
)
|
||||
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
@ -147,11 +147,12 @@ class GPTQTest(unittest.TestCase):
|
||||
|
||||
def test_device_and_dtype_assignment(self):
|
||||
r"""
|
||||
Test whether trying to cast (or assigning a device to) a model after converting it in 8-bit will throw an error.
|
||||
Test whether trying to cast (or assigning a device to) a model after quantization will throw an error.
|
||||
Checks also if other models are casted correctly.
|
||||
"""
|
||||
# This should work
|
||||
_ = self.quantized_model.to(0)
|
||||
if self.device_map is None:
|
||||
_ = self.quantized_model.to(0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
# Tries with a `dtype``
|
||||
@ -177,7 +178,8 @@ class GPTQTest(unittest.TestCase):
|
||||
desc_act=self.desc_act,
|
||||
group_size=self.group_size,
|
||||
bits=self.bits,
|
||||
disable_exllama=self.disable_exllama,
|
||||
disable_exllama=not self.use_exllama,
|
||||
disable_exllamav2=True,
|
||||
)
|
||||
self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)
|
||||
|
||||
@ -196,6 +198,9 @@ class GPTQTest(unittest.TestCase):
|
||||
# Get the generation
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def check_quantized_layers_type(self, model, value):
|
||||
self.assertTrue(model.transformer.h[0].mlp.dense_4h_to_h.QUANT_TYPE == value)
|
||||
|
||||
def test_generate_quality(self):
|
||||
"""
|
||||
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
|
||||
@ -211,11 +216,13 @@ class GPTQTest(unittest.TestCase):
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname)
|
||||
if self.disable_exllama:
|
||||
if not self.use_exllama:
|
||||
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname).to(0)
|
||||
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
|
||||
else:
|
||||
# we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel
|
||||
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(tmpdirname, device_map={"": 0})
|
||||
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")
|
||||
self.check_inference_correctness(quantized_model_from_saved)
|
||||
|
||||
@require_accelerate
|
||||
@ -234,14 +241,15 @@ class GPTQTest(unittest.TestCase):
|
||||
"""
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
self.quantized_model.save_pretrained(tmpdirname)
|
||||
if self.disable_exllama:
|
||||
self.assertEqual(self.quantized_model.config.quantization_config.disable_exllama, True)
|
||||
if not self.use_exllama:
|
||||
self.assertEqual(self.quantized_model.config.quantization_config.use_exllama, False)
|
||||
# we need to put it directly to the gpu. Otherwise, we won't be able to initialize the exllama kernel
|
||||
quantized_model_from_saved = AutoModelForCausalLM.from_pretrained(
|
||||
tmpdirname, quantization_config=GPTQConfig(disable_exllama=False, bits=4), device_map={"": 0}
|
||||
tmpdirname, quantization_config=GPTQConfig(use_exllama=True, bits=4), device_map={"": 0}
|
||||
)
|
||||
self.assertEqual(quantized_model_from_saved.config.quantization_config.disable_exllama, False)
|
||||
self.assertEqual(quantized_model_from_saved.config.quantization_config.use_exllama, True)
|
||||
self.assertEqual(quantized_model_from_saved.config.quantization_config.bits, self.bits)
|
||||
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")
|
||||
self.check_inference_correctness(quantized_model_from_saved)
|
||||
|
||||
|
||||
@ -255,7 +263,7 @@ class GPTQTestDeviceMap(GPTQTest):
|
||||
@require_torch_multi_gpu
|
||||
class GPTQTestDeviceMapExllama(GPTQTest):
|
||||
device_map = "auto"
|
||||
disable_exllama = False
|
||||
use_exllama = True
|
||||
|
||||
|
||||
@slow
|
||||
@ -281,8 +289,7 @@ class GPTQTestActOrderExllama(unittest.TestCase):
|
||||
"""
|
||||
Setup quantized model
|
||||
"""
|
||||
|
||||
cls.quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028)
|
||||
cls.quantization_config = GPTQConfig(bits=4, max_input_length=4028)
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
cls.model_name,
|
||||
revision=cls.revision,
|
||||
@ -308,14 +315,15 @@ class GPTQTestActOrderExllama(unittest.TestCase):
|
||||
# Get the generation
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_quantized_layers_type(self):
|
||||
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllama")
|
||||
|
||||
def test_generate_quality(self):
|
||||
"""
|
||||
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
|
||||
"""
|
||||
self.check_inference_correctness(self.quantized_model)
|
||||
|
||||
# this test will fail until the next release of optimum
|
||||
@pytest.mark.skip
|
||||
def test_max_input_length(self):
|
||||
"""
|
||||
Test if the max_input_length works. It modifies the maximum input length that of the model that runs with exllama backend.
|
||||
@ -334,6 +342,65 @@ class GPTQTestActOrderExllama(unittest.TestCase):
|
||||
self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
|
||||
|
||||
|
||||
@slow
|
||||
@require_optimum
|
||||
@require_auto_gptq
|
||||
@require_torch_gpu
|
||||
@require_accelerate
|
||||
class GPTQTestExllamaV2(unittest.TestCase):
|
||||
"""
|
||||
Test GPTQ model with exllamav2 kernel and desc_act=True (also known as act-order).
|
||||
More information on those arguments here:
|
||||
https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
|
||||
"""
|
||||
|
||||
EXPECTED_OUTPUTS = set()
|
||||
EXPECTED_OUTPUTS.add("Hello my name is Katie and I am a 20 year")
|
||||
model_name = "hf-internal-testing/Llama-2-7B-GPTQ"
|
||||
revision = "gptq-4bit-128g-actorder_True"
|
||||
input_text = "Hello my name is"
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
"""
|
||||
Setup quantized model
|
||||
"""
|
||||
cls.quantization_config = GPTQConfig(bits=4, exllama_config={"version": 2})
|
||||
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
|
||||
cls.model_name,
|
||||
revision=cls.revision,
|
||||
torch_dtype=torch.float16,
|
||||
device_map={"": 0},
|
||||
quantization_config=cls.quantization_config,
|
||||
)
|
||||
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)
|
||||
|
||||
def test_quantized_layers_type(self):
|
||||
self.assertTrue(self.quantized_model.model.layers[0].self_attn.k_proj.QUANT_TYPE == "exllamav2")
|
||||
|
||||
def check_inference_correctness(self, model):
|
||||
"""
|
||||
Test the generation quality of the quantized model and see that we are matching the expected output.
|
||||
Given that we are operating on small numbers + the testing model is relatively small, we might not get
|
||||
the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
|
||||
"""
|
||||
|
||||
# Check that inference pass works on the model
|
||||
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
|
||||
|
||||
# Check the exactness of the results
|
||||
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
|
||||
|
||||
# Get the generation
|
||||
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
|
||||
|
||||
def test_generate_quality(self):
|
||||
"""
|
||||
Simple test to check the quality of the model by comapring the the generated tokens with the expected tokens
|
||||
"""
|
||||
self.check_inference_correctness(self.quantized_model)
|
||||
|
||||
|
||||
# fail when run all together
|
||||
@pytest.mark.skip
|
||||
@require_accelerate
|
||||
|
Loading…
Reference in New Issue
Block a user