Mirror of https://github.com/huggingface/transformers.git
Exllama kernels support for AWQ models (#28634)
* added exllama kernels support for awq models
* doc
* style
* Update src/transformers/modeling_utils.py
  Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
* refactor
* moved exllama post init to after device dispatching
* bump autoawq version
* added exllama test
* style
* configurable exllama kernels
* copy exllama_config from gptq
* moved exllama version check to post init
* moved to quantization dockerfile

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 81c8191b46
commit 4fc708f98c
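For context, a minimal usage sketch of the feature this commit adds (not part of the diff). It mirrors the new test below; the checkpoint id is a placeholder, and it assumes a CUDA GPU with `autoawq >= 0.2.0` installed.

```python
# Sketch only (not part of the diff): exercise the new Exllama backend on an AWQ checkpoint.
# Assumptions: CUDA GPU available, autoawq >= 0.2.0 installed, placeholder AWQ checkpoint id.
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig

model_id = "TheBloke/Mistral-7B-v0.1-AWQ"  # placeholder; any AWQ-quantized checkpoint works

# version="exllama" selects the Exllama kernels; exllama_config defaults to
# {"version": 2, "max_input_len": 2048, "max_batch_size": 8} when left unset.
quantization_config = AwqConfig(version="exllama")

model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=quantization_config
).to("cuda:0")
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hello my name is", return_tensors="pt").to("cuda:0")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=40)[0], skip_special_tokens=True))
```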
@@ -43,7 +43,7 @@ RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/opt
 RUN python3 -m pip install --no-cache-dir aqlm[gpu]==1.0.2
 
 # Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.0/autoawq-0.2.0+cu118-cp38-cp38-linux_x86_64.whl
 
 # When installing in editable mode, `transformers` is not recognized as a package.
 # this line must be added in order for python to be aware of transformers.
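The wheel bump above matters because the Exllama code paths only exist in newer `autoawq` releases. A small sketch of the version gate, mirroring the `MIN_AWQ_VERSION` check added to `AwqConfig.post_init()` later in this diff:

```python
# Sketch: verify the installed autoawq is new enough for the Exllama backend.
import importlib.metadata

from packaging import version

MIN_AWQ_VERSION = "0.2.0"
installed = version.parse(importlib.metadata.version("autoawq"))
print("Exllama backend supported:", installed >= version.parse(MIN_AWQ_VERSION))
```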
@@ -18,7 +18,11 @@ from ..utils import _LazyModule
 
 _import_structure = {
     "aqlm": ["replace_with_aqlm_linear"],
-    "awq": ["fuse_awq_modules", "replace_with_awq_linear"],
+    "awq": [
+        "fuse_awq_modules",
+        "post_init_awq_exllama_modules",
+        "replace_with_awq_linear",
+    ],
     "bitsandbytes": [
         "get_keys_to_not_convert",
         "replace_8bit_linear",
@@ -82,7 +86,11 @@ _import_structure = {
 
 if TYPE_CHECKING:
     from .aqlm import replace_with_aqlm_linear
-    from .awq import fuse_awq_modules, replace_with_awq_linear
+    from .awq import (
+        fuse_awq_modules,
+        post_init_awq_exllama_modules,
+        replace_with_awq_linear,
+    )
     from .bitsandbytes import (
         get_keys_to_not_convert,
         replace_8bit_linear,
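With the import structure updated, the new helper resolves through the lazy module just like the existing AWQ utilities. A quick sanity check, assuming a checkout that includes this commit:

```python
# The new helper is importable from transformers.integrations (resolved lazily via _LazyModule,
# alongside fuse_awq_modules and replace_with_awq_linear).
from transformers.integrations import post_init_awq_exllama_modules

print(post_init_awq_exllama_modules.__module__)  # transformers.integrations.awq
```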
@@ -15,7 +15,12 @@
 from ..activations import ACT2FN
 from ..modeling_utils import PreTrainedModel
 from ..utils import is_auto_awq_available, is_torch_available
-from ..utils.quantization_config import AwqBackendPackingMethod, AwqConfig, AWQLinearVersion
+from ..utils.quantization_config import (
+    AwqBackendPackingMethod,
+    AwqConfig,
+    AWQLinearVersion,
+    ExllamaVersion,
+)
 
 
 if is_torch_available():
@@ -91,13 +96,30 @@ def replace_with_awq_linear(
         )
 
     if backend == AwqBackendPackingMethod.AUTOAWQ:
-        from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-    elif backend == AwqBackendPackingMethod.LLMAWQ:
-        from awq.quantize.qmodule import WQLinear
-
-    if backend == AwqBackendPackingMethod.AUTOAWQ:
-        target_cls = WQLinear_GEMM if quantization_config.version == AWQLinearVersion.GEMM else WQLinear_GEMV
-    else:
-        target_cls = WQLinear
+        if quantization_config.version == AWQLinearVersion.GEMM:
+            from awq.modules.linear.gemm import WQLinear_GEMM
+
+            target_cls = WQLinear_GEMM
+        elif quantization_config.version == AWQLinearVersion.GEMV:
+            from awq.modules.linear.gemv import WQLinear_GEMV
+
+            target_cls = WQLinear_GEMV
+        elif quantization_config.version == AWQLinearVersion.EXLLAMA:
+            if quantization_config.exllama_config["version"] == ExllamaVersion.ONE:
+                from awq.modules.linear.exllama import WQLinear_Exllama
+
+                target_cls = WQLinear_Exllama
+            elif quantization_config.exllama_config["version"] == ExllamaVersion.TWO:
+                from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
+
+                target_cls = WQLinear_ExllamaV2
+            else:
+                raise ValueError(f"Unrecognized Exllama version: {quantization_config.exllama_config['version']}")
+        else:
+            raise ValueError(f"Unrecognized AWQ version: {quantization_config.version}")
     else:
-        target_cls = WQLinear
+        from awq.quantize.qmodule import WQLinear
+
+        target_cls = WQLinear
 
     for name, module in model.named_children():
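A condensed, illustration-only view of the dispatch above. The helper below is hypothetical; the real code imports the autoawq classes lazily inside `replace_with_awq_linear` and assigns them to `target_cls`.

```python
# Hypothetical helper, for illustration only: which autoawq linear class each
# configuration selects in the dispatch added above.
from transformers.utils.quantization_config import AWQLinearVersion, ExllamaVersion


def awq_linear_class_name(awq_version, exllama_version=ExllamaVersion.TWO):
    if awq_version == AWQLinearVersion.GEMM:
        return "awq.modules.linear.gemm.WQLinear_GEMM"
    if awq_version == AWQLinearVersion.GEMV:
        return "awq.modules.linear.gemv.WQLinear_GEMV"
    if awq_version == AWQLinearVersion.EXLLAMA:
        if exllama_version == ExllamaVersion.ONE:
            return "awq.modules.linear.exllama.WQLinear_Exllama"
        return "awq.modules.linear.exllamav2.WQLinear_ExllamaV2"
    raise ValueError(f"Unrecognized AWQ version: {awq_version}")


print(awq_linear_class_name(AWQLinearVersion.EXLLAMA))  # -> WQLinear_ExllamaV2
```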
@@ -372,3 +394,28 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
         setattr(parent, child_name, fused_attention_layer.to(previous_device))
 
         del q_proj, k_proj, v_proj, o_proj
+
+
+def post_init_awq_exllama_modules(model, exllama_config):
+    """
+    Runs post init for Exllama layers which performs:
+        - Weights unpacking, reordering and repacking
+        - Devices scratch space allocation
+    """
+
+    if exllama_config["version"] == ExllamaVersion.ONE:
+        from awq.modules.linear.exllama import exllama_post_init
+
+        model = exllama_post_init(model)
+    elif exllama_config["version"] == ExllamaVersion.TWO:
+        from awq.modules.linear.exllamav2 import exllamav2_post_init
+
+        model = exllamav2_post_init(
+            model,
+            max_input_len=exllama_config["max_input_len"],
+            max_batch_size=exllama_config["max_batch_size"],
+        )
+    else:
+        raise ValueError(f"Unrecognized Exllama version: {exllama_config['version']}")
+
+    return model
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
 from ..utils import is_accelerate_available, is_auto_awq_available, is_torch_available, logging
+from ..utils.quantization_config import AWQLinearVersion
 
 
 if is_torch_available():
@@ -98,12 +99,22 @@ class AwqQuantizer(HfQuantizer):
             model = fuse_awq_modules(model, self.quantization_config)
             model._awq_is_fused = True  # TODO: consider storing this flag in model.config instead
 
+        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
+            from ..integrations import post_init_awq_exllama_modules
+
+            model = post_init_awq_exllama_modules(model, self.quantization_config.exllama_config)
+
     @property
     def is_serializable(self):
         # AWQ through auto-awq has been always serializable, except if the model is fused.
         if self.quantization_config.do_fuse:
             logger.warning("You cannot save an AWQ model that uses fused modules!")
             return False
 
+        if self.quantization_config.version == AWQLinearVersion.EXLLAMA:
+            logger.warning("You cannot save an AWQ model that uses Exllama backend!")
+            return False
+
         return True
 
     @property
@@ -44,6 +44,7 @@ class QuantizationMethod(str, Enum):
 class AWQLinearVersion(str, Enum):
     GEMM = "gemm"
     GEMV = "gemv"
+    EXLLAMA = "exllama"
 
     @staticmethod
     def from_str(version: str):
@@ -52,6 +53,8 @@ class AWQLinearVersion(str, Enum):
             return AWQLinearVersion.GEMM
         elif version == "gemv":
             return AWQLinearVersion.GEMV
+        elif version == "exllama":
+            return AWQLinearVersion.EXLLAMA
         else:
             raise ValueError(f"Unknown AWQLinearVersion {version}")
 
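The new enum member round-trips through `from_str`, which is what lets users simply pass `version="exllama"` to `AwqConfig`:

```python
# The string "exllama" now maps to the new enum member.
from transformers.utils.quantization_config import AWQLinearVersion

assert AWQLinearVersion.from_str("exllama") is AWQLinearVersion.EXLLAMA
```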
@@ -606,7 +609,7 @@ class AwqConfig(QuantizationConfigMixin):
             Whether to use zero point quantization.
         version (`AWQLinearVersion`, *optional*, defaults to `AWQLinearVersion.GEMM`):
             The version of the quantization algorithm to use. GEMM is better for big batch_size (e.g. >= 8) otherwise,
-            GEMV is better (e.g. < 8 )
+            GEMV is better (e.g. < 8 ). GEMM models are compatible with Exllama kernels.
         backend (`AwqBackendPackingMethod`, *optional*, defaults to `AwqBackendPackingMethod.AUTOAWQ`):
             The quantization backend. Some models might be quantized using `llm-awq` backend. This is useful for users
             that quantize their own models using `llm-awq` library.
@@ -620,6 +623,10 @@ class AwqConfig(QuantizationConfigMixin):
             The list of modules to not quantize, useful for quantizing models that explicitly require to have
             some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
             Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
+        exllama_config (`Dict[str, Any]`, *optional*):
+            You can specify the version of the exllama kernel through the `version` key, the maximum sequence
+            length through the `max_input_len` key, and the maximum batch size through the `max_batch_size` key.
+            Defaults to `{"version": 2, "max_input_len": 2048, "max_batch_size": 8}` if unset.
     """
 
     def __init__(
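An example configuration using the documented keys; the non-default values are arbitrary, and `post_init()` validates them, so `autoawq >= 0.2.0` must be installed:

```python
# Constructing an AwqConfig with an explicit exllama_config, using the documented keys.
from transformers import AwqConfig

config = AwqConfig(
    version="exllama",
    exllama_config={"version": 2, "max_input_len": 4096, "max_batch_size": 4},  # arbitrary example values
)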
@@ -633,6 +640,7 @@ class AwqConfig(QuantizationConfigMixin):
         fuse_max_seq_len: Optional[int] = None,
         modules_to_fuse: Optional[dict] = None,
         modules_to_not_convert: Optional[List] = None,
+        exllama_config: Optional[Dict[str, int]] = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.AWQ
@@ -644,6 +652,7 @@ class AwqConfig(QuantizationConfigMixin):
         self.backend = backend
         self.fuse_max_seq_len = fuse_max_seq_len
         self.modules_to_not_convert = modules_to_not_convert
+        self.exllama_config = exllama_config
 
         self.modules_to_fuse = modules_to_fuse
         if do_fuse is None:
@@ -667,9 +676,9 @@ class AwqConfig(QuantizationConfigMixin):
             )
 
         self.version = AWQLinearVersion.from_str(self.version)
-        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV]:
+        if self.version not in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA]:
             raise ValueError(
-                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV] - not recognized version {self.version}"
+                f"Only supported versions are in [AWQLinearVersion.GEMM, AWQLinearVersion.GEMV, AWQLinearVersion.EXLLAMA] - not recognized version {self.version}"
             )
 
         if self.backend == AwqBackendPackingMethod.LLMAWQ:
@@ -724,9 +733,34 @@ class AwqConfig(QuantizationConfigMixin):
                     f"Required fields are missing in the fusing mapping, required fields are {required_keys}"
                 )
 
+        if self.version == AWQLinearVersion.EXLLAMA:
+            awq_version_supports_exllama = False
+            MIN_AWQ_VERSION = "0.2.0"
+            if is_auto_awq_available():
+                awq_version_supports_exllama = version.parse(importlib.metadata.version("autoawq")) >= version.parse(
+                    MIN_AWQ_VERSION
+                )
+
+            if not awq_version_supports_exllama:
+                raise ValueError(
+                    f"You current version of `autoawq` does not support exllama backend, "
+                    f"please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
+                )
+
+            if self.exllama_config is None:
+                self.exllama_config = {"version": ExllamaVersion.TWO, "max_input_len": 2048, "max_batch_size": 8}
+            else:
+                if "version" not in self.exllama_config:
+                    raise ValueError("`exllama_config` needs to have a `version` key.")
+                elif self.exllama_config["version"] not in [ExllamaVersion.ONE, ExllamaVersion.TWO]:
+                    exllama_version = self.exllama_config["version"]
+                    raise ValueError(
+                        f"Only supported versions are in [ExllamaVersion.ONE, ExllamaVersion.TWO] - not recognized version {exllama_version}"
+                    )
+
     def get_loading_attributes(self):
         attibutes_dict = copy.deepcopy(self.__dict__)
-        loading_attibutes = ["do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
+        loading_attibutes = ["version", "do_fuse", "modules_to_fuse", "fuse_max_seq_len"]
         loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
         return loading_attibutes_dict
 
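A behavior sketch of the validation added above (assumes `autoawq >= 0.2.0` is installed; otherwise the first construction already raises):

```python
from transformers import AwqConfig

# Defaults are filled in by post_init() when exllama_config is omitted.
cfg = AwqConfig(version="exllama")
print(cfg.exllama_config)  # e.g. {'version': <ExllamaVersion.TWO: 2>, 'max_input_len': 2048, 'max_batch_size': 8}

# A config without a "version" key is rejected.
try:
    AwqConfig(version="exllama", exllama_config={"max_input_len": 2048, "max_batch_size": 8})
except ValueError as err:
    print(err)  # `exllama_config` needs to have a `version` key.
```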
@@ -192,6 +192,20 @@ class AwqTest(unittest.TestCase):
         output = quantized_model.generate(**input_ids, max_new_tokens=40)
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT_BF16)
 
+    def test_quantized_model_exllama(self):
+        """
+        Simple test that checks if the quantized model is working properly with exllama backend
+        """
+        input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(torch_device)
+
+        quantization_config = AwqConfig(version="exllama")
+        quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name, quantization_config=quantization_config
+        ).to(torch_device)
+
+        output = quantized_model.generate(**input_ids, max_new_tokens=40)
+        self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
+
     def test_quantized_model_no_device_map(self):
         """
         Simple test that checks if the quantized model is working properly