Exllama kernels support for AWQ models (#28634)

* added exllama kernels support for awq models

* doc

* style

* Update src/transformers/modeling_utils.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* refactor

* moved exllama post init to after device dispatching

* bump autoawq version

* added exllama test

* style

* configurable exllama kernels

* copy exllama_config from gptq

* moved exllama version check to post init

* moved to quantization dockerfile

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Ilyas Moutawwakil 2024-03-05 03:22:48 +01:00 committed by GitHub
parent 81c8191b46
commit 4fc708f98c
6 changed files with 127 additions and 13 deletions
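A minimal usage sketch of the Exllama-backed AWQ path exercised by the test at the end of this diff; the checkpoint id and device_map below are illustrative assumptions, and the only required piece is AwqConfig(version="exllama"):

from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

# Placeholder checkpoint: any AutoAWQ GEMM-quantized model should work here.
model_id = "TheBloke/Mistral-7B-v0.1-AWQ"

# version="exllama" routes the quantized linear layers through the Exllama kernels;
# kernel settings can be tuned via `exllama_config` (see the config diff below).
quantization_config = AwqConfig(version="exllama")

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="cuda:0",
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello, my name is", return_tensors="pt").to(model.device)
output = model.generate(**inputs, max_new_tokens=40)
print(tokenizer.decode(output[0], skip_special_tokens=True))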


@@ -43,7 +43,7 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
# Add autoawq for quantization testing
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl
# When installing in editable mode, `transformers` is not recognized as a package.
# this line must be added in order for python to be aware of transformers.


@@ -18,7 +18,11 @@ from ..utils import _LazyModule
_import_structure = {
"aqlm": ["replace_with_aqlm_linear"],
"awq": ["fuse_awq_modules", "replace_with_awq_linear"],
"awq": [
"fuse_awq_modules",
"post_init_awq_exllama_modules",
"replace_with_awq_linear",
],
"bitsandbytes": [
"get_keys_to_not_convert",
"replace_8bit_linear",
@@ -82,7 +86,11 @@ _import_structure = {
if TYPE_CHECKING:
from .aqlm import replace_with_aqlm_linear
from .awq import fuse_awq_modules, replace_with_awq_linear
from .awq import (
fuse_awq_modules,
post_init_awq_exllama_modules,
replace_with_awq_linear,
)
from .bitsandbytes import (
get_keys_to_not_convert,
replace_8bit_linear,


@@ -15,7 +15,12 @@
from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_torch_available
from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
from ..utils.quantization_config import (
AwqBackendPackingMethod,
AwqConfig,
AWQLinearVersion,
ExllamaVersion,
)
if is_torch_available():
@@ -91,13 +96,30 @@ def replace_with_awq_linear(
)
if backend == AwqBackendPackingMethod.AUTOAWQ:
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
elif backend == AwqBackendPackingMethod.LLMAWQ:
if quantization_config.version == AWQLinearVersion.GEMM:
from awq.modules.linear.gemm import WQLinear_GEMM
target_cls = WQLinear_GEMM
elif quantization_config.version == AWQLinearVersion.GEMV:
from awq.modules.linear.gemv import WQLinear_GEMV
target_cls = WQLinear_GEMV
elif quantization_config.version == AWQLinearVersion.EXLLAMA:
if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
from awq.modules.linear.exllama import WQLinear_Exllama
target_cls = WQLinear_Exllama
elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
target_cls = WQLinear_ExllamaV2
else:
raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
else:
raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
else:
from awq.quantize.qmodule import WQLinear
if backend == AwqBackendPackingMethod.AUTOAWQ:
target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
else:
target_cls = WQLinear
for name, module in model.named_children():
@@ -372,3 +394,28 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
setattr(parent, child_name, fused_attention_layer.to(previous_device))
del q_proj, k_proj, v_proj, o_proj
def post_init_awq_exllama_modules(model, exllama_config):
"""
Runs post init for Exllama layers which performs:
- Weights unpacking, reordering and repacking
- Device scratch space allocation
"""
if exllama_config["version"] == ExllamaVersion.ONE:
from awq.modules.linear.exllama import exllama_post_init
model = exllama_post_init(model)
elif exllama_config["version"] == ExllamaVersion.TWO:
from awq.modules.linear.exllamav2 import exllamav2_post_init
model = exllamav2_post_init(
model,
max_input_len=exllama_config["max_input_len"],
max_batch_size=exllama_config["max_batch_size"],
)
else:
raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")
return model
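A hedged sketch of calling this helper directly; in the library it is invoked by the AWQ quantizer after weight loading, as the quantizer diff below shows. Here `model` is assumed to already hold Exllama AWQ linear layers dispatched to their target devices:

from transformers.integrations import post_init_awq_exllama_modules
from transformers.utils.quantization_config import ExllamaVersion

# `model` is assumed to be an AWQ model whose linear layers were already replaced
# with WQLinear_Exllama / WQLinear_ExllamaV2 and moved to their devices.
# The keys mirror the documented defaults; ExllamaVersion.TWO selects the v2 kernels.
exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
model = post_init_awq_exllama_modules(model, exllama_config)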


@@ -23,6 +23,7 @@ if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
from ..utils.quantization_config import AWQLinearVersion
if is_torch_available():
@@ -98,12 +99,22 @@ class AwqQuantizer(HfQuantizer):
model = fuse_awq_modules(model, self.quantization_config)
model._awq_is_fused = True # TODO: consider storing this flag in model.config instead
if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
from ..integrations import post_init_awq_exllama_modules
model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
@property
def is_serializable(self):
# AWQ through auto-awq has always been serializable, except if the model is fused.
if self.quantization_config.do_fuse:
logger.warning("You cannot save an AWQ model that uses fused modules!")
return False
if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
logger.warning("You cannot save an AWQ model that uses Exllama backend!")
return False
return True
@property


@@ -44,6 +44,7 @@ class QuantizationMethod(str, Enum):
class AWQLinearVersion(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
EXLLAMA = "exllama"
@staticmethod
def from_str(version: str):
@@ -52,6 +53,8 @@ class AWQLinearVersion(str, Enum):
return AWQLinearVersion.GEMM
elif version == "gemv":
return AWQLinearVersion.GEMV
elif version == "exllama":
return AWQLinearVersion.EXLLAMA
else:
raise ValueError(f"Unknown AWQLinearVersion {version}")
@@ -606,7 +609,7 @@ class AwqConfig(QuantizationConfigMixin):
Whether to use zero point quantization.
version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8) otherwise,
GEMV is better (e.g. < 8 )
GEMV is better (e.g. < 8 ). GEMM models are compatible with Exllama kernels.
backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
that quantize their own models using `llm-awq` library.
@@ -620,6 +623,10 @@ class AwqConfig(QuantizationConfigMixin):
The list of modules to not quantize, useful for quantizing models that explicitly require to have
some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
exllama_config (`Dict[str, Any]`, *optional*):
You can specify the version of the exllama kernel through the `version` key, the maximum sequence
length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
"""
def __init__(
@@ -633,6 +640,7 @@ class AwqConfig(QuantizationConfigMixin):
fuse_max_seq_len: Optional[int] = None,
modules_to_fuse: Optional[dict] = None,
modules_to_not_convert: Optional[List] = None,
exllama_config: Optional[Dict[str, int]] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.AWQ
@@ -644,6 +652,7 @@ class AwqConfig(QuantizationConfigMixin):
self.backend = backend
self.fuse_max_seq_len = fuse_max_seq_len
self.modules_to_not_convert = modules_to_not_convert
self.exllama_config = exllama_config
self.modules_to_fuse = modules_to_fuse
if do_fuse is None:
@@ -667,9 +676,9 @@
)
self.version = AWQLinearVersion.from_str(self.version)
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
raise ValueError(
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
)
if self.backend == AwqBackendPackingMethod.LLMAWQ:
@@ -724,9 +733,34 @@
f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
)
if self.version == AWQLinearVersion.EXLLAMA:
awq_version_supports_exllama = False
MIN_AWQ_VERSION = "0.2.0"
if is_auto_awq_available():
awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
MIN_AWQ_VERSION
)
if not awq_version_supports_exllama:
raise ValueError(
f"You current version of `autoawq` does not support exllama backend, "
f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
)
if self.exllama_config is None:
self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
else:
if "version" not in self.exllama_config:
raise ValueError("`exllama_config` needs to have a `version` key.")
elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
exllama_version = self.exllama_config["version"]
raise ValueError(
f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
)
def get_loading_attributes(self):
attributes_dict = copy.deepcopy(self.__dict__)
loading_attributes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
loading_attributes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
loading_attributes_dict = {i: j for i, j in attributes_dict.items() if i in loading_attributes}
return loading_attributes_dict
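A short sketch of the new `exllama_config` argument documented above; the values here are arbitrary illustrations of the supported keys:

from transformers import AwqConfig

# "version" selects the Exllama kernel generation (1 or 2); the other keys bound the
# scratch space allocated during the ExllamaV2 post init.
quantization_config = AwqConfig(
    version="exllama",
    exllama_config={"version": 2, "max_input_len": 4096, "max_batch_size": 4},
)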


@@ -192,6 +192,20 @@ class AwqTest(unittest.TestCase):
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)
def test_quantized_model_exllama(self):
"""
Simple test that checks if the quantized model is working properly with exllama backend
"""
input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
quantization_config = AwqConfig(version="exllama")
quantized_model = AutoModelForCausalLM.from_pretrained(
self.model_name, quantization_config=quantization_config
).to(torch_device)
output = quantized_model.generate(**input_ids, max_new_tokens=40)
self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
def test_quantized_model_no_device_map(self):
"""
Simple test that checks if the quantized model is working properly