Add support for loading GPTQ models on CPU (#26719)

* Add support for loading GPTQ models on CPU

Right now, GPTQ-quantized models can only be loaded on a CUDA
device. The attribute `gptq_supports_cpu` checks whether the
installed auto_gptq version is one that supports CPU execution
(sketched below).
The larger variants of a model are hard to load/run/trace on a
GPU, and that is the rationale behind adding this attribute.
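
A minimal, standalone sketch of that version gate (the in-tree placement
is shown in the modeling_utils.py hunk further down; `torch` and
`packaging` are assumed to be available, as they are inside transformers):

```py
import importlib.metadata

import torch
from packaging import version

# CPU loading is only possible with auto-gptq releases newer than 0.4.2.
gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
if not gptq_supports_cpu and not torch.cuda.is_available():
    raise RuntimeError("GPU is required to quantize or run quantize model.")
```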

Signed-off-by: Vivek Khandelwal <vivek@nod-labs.com>

* Update quantization.md

* Update quantization.md

* Update quantization.md
Vivek Khandelwal 2023-10-31 19:15:23 +05:30 committed by GitHub
parent 3cd3eaf960
commit 2963e196ee
2 changed files with 3 additions and 2 deletions

quantization.md

@@ -132,7 +132,7 @@ model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", de
### Exllama kernels for faster inference
-For 4-bit models, you can use the exllama kernels for faster inference speed. They are activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels.
+For 4-bit models, you can use the exllama kernels for faster inference speed. They are activated by default. You can change that behavior by passing `disable_exllama` in [`GPTQConfig`]. This will overwrite the quantization config stored in the config. Note that you will only be able to overwrite the attributes related to the kernels. Furthermore, you need to have the entire model on GPUs if you want to use the exllama kernels. Also, you can perform CPU inference with Auto-GPTQ versions > 0.4.2 by passing `device_map="cpu"`. For CPU inference, you have to pass `disable_exllama=True` in [`GPTQConfig`].
```py
import torch

modeling_utils.py

@@ -2788,7 +2788,8 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
quantization_method_from_args == QuantizationMethod.GPTQ
or quantization_method_from_config == QuantizationMethod.GPTQ
):
-if not torch.cuda.is_available():
+gptq_supports_cpu = version.parse(importlib.metadata.version("auto-gptq")) > version.parse("0.4.2")
+if not gptq_supports_cpu and not torch.cuda.is_available():
raise RuntimeError("GPU is required to quantize or run quantize model.")
elif not (is_optimum_available() and is_auto_gptq_available()):
raise ImportError(
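
Putting the two hunks together, CPU inference with a pre-quantized GPTQ
checkpoint looks roughly like the following minimal sketch, assuming
auto-gptq > 0.4.2 and optimum are installed alongside a transformers build
containing this change, and reusing the illustrative
`{your_username}/opt-125m-gptq` checkpoint name from the docs hunk above:

```py
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Illustrative checkpoint name from the docs above; substitute your own quantized repo.
model_id = "{your_username}/opt-125m-gptq"

# disable_exllama is required for CPU inference; the exllama kernels need the whole model on GPU.
quantization_config = GPTQConfig(bits=4, disable_exllama=True)

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="cpu",
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```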