[Awq] Enable the possibility to skip quantization for some target modules (#27950)

* v1
* add docstring
* add tests
* add awq 0.1.8
* oops
* fix test
parent 29e7a1e183
commit fa21ead73d
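What the change enables, as a minimal sketch (the checkpoint name is the one added to the tests below and is assumed to carry `k_proj` in its `modules_to_not_convert`; device placement is illustrative):

    import torch
    from transformers import AutoModelForCausalLM

    # Loading an AWQ checkpoint whose quantization config lists modules to skip
    # keeps those modules as plain nn.Linear instead of quantized linear layers.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/opt-125m-awq-no-k-proj").to("cuda")

    # k_proj was skipped at quantization time, so it stays a regular torch.nn.Linear ...
    assert isinstance(model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear)
    # ... while the other projections were converted to AWQ quantized modules.
    assert not isinstance(model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear)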
@@ -56,7 +56,7 @@ RUN python3 -m pip install --no-cache-dir auto-gptq --extra-index-url https://hu
 RUN python3 -m pip install --no-cache-dir einops
 
 # Add autoawq for quantization testing
-RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.7/autoawq-0.1.7+cu118-cp38-cp38-linux_x86_64.whl
+RUN python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.8/autoawq-0.1.8+cu118-cp38-cp38-linux_x86_64.whl
 
 # For bettertransformer + gptq
 RUN python3 -m pip install --no-cache-dir git+https://github.com/huggingface/optimum@main#egg=optimum
@@ -3575,6 +3575,9 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
             if quantization_config is None:
                 quantization_config = AwqConfig.from_dict(config.quantization_config)
 
+            if quantization_config.modules_to_not_convert is not None:
+                modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
+
             model, has_been_replaced = replace_with_awq_linear(
                 model, quantization_config=quantization_config, modules_to_not_convert=modules_to_not_convert
             )
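The merged skip list is then handed to `replace_with_awq_linear`. A simplified sketch of how a name-based skip list can gate the replacement (an illustration, not the library's actual implementation):

    import torch.nn as nn

    def _is_skipped(qualified_name, modules_to_not_convert):
        # A module is left untouched when any skip-list entry occurs in its fully
        # qualified name, e.g. "k_proj" matches "model.decoder.layers.0.self_attn.k_proj".
        return any(key in qualified_name for key in modules_to_not_convert)

    def convert_linears(model: nn.Module, modules_to_not_convert):
        for name, module in model.named_modules():
            if isinstance(module, nn.Linear) and not _is_skipped(name, modules_to_not_convert):
                # The real code swaps in an AWQ quantized linear here; omitted in this sketch.
                pass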
@@ -564,6 +564,10 @@ class AwqConfig(QuantizationConfigMixin):
             The Maximum sequence length to generate when using fusing.
         modules_to_fuse (`dict`, *optional*, default to `None`):
             Overwrite the natively supported fusing scheme with the one specified by the users.
+        modules_to_not_convert (`list`, *optional*, default to `None`):
+            The list of modules to not quantize, useful for quantizing models that explicitly require to have
+            some modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
+            Note you cannot quantize directly with transformers, please refer to `AutoAWQ` documentation for quantizing HF models.
     """
 
     def __init__(
@@ -576,6 +580,7 @@ class AwqConfig(QuantizationConfigMixin):
         do_fuse: Optional[bool] = None,
         fuse_max_seq_len: Optional[int] = None,
         modules_to_fuse: Optional[dict] = None,
+        modules_to_not_convert: Optional[List] = None,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.AWQ
@@ -586,6 +591,7 @@ class AwqConfig(QuantizationConfigMixin):
         self.version = version
         self.backend = backend
         self.fuse_max_seq_len = fuse_max_seq_len
+        self.modules_to_not_convert = modules_to_not_convert
 
         self.modules_to_fuse = modules_to_fuse
         if do_fuse is None:
@@ -638,6 +644,19 @@ class AwqConfig(QuantizationConfigMixin):
                     f"You current version of `autoawq` does not support module fusing, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
                 )
 
+        if self.modules_to_not_convert is not None:
+            awq_version_supports_non_conversion = False
+            MIN_AWQ_VERSION = "0.1.8"
+            if is_auto_awq_available():
+                awq_version_supports_non_conversion = version.parse(
+                    importlib.metadata.version("autoawq")
+                ) >= version.parse(MIN_AWQ_VERSION)
+
+            if not awq_version_supports_non_conversion:
+                raise ValueError(
+                    f"You current version of `autoawq` does not support module quantization skipping, please upgrade `autoawq` package to at least {MIN_AWQ_VERSION}."
+                )
+
         if self.do_fuse and self.modules_to_fuse is not None:
             required_keys = [
                 "hidden_size",
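On the config side, users can now declare the skipped modules directly. A hedged sketch of constructing such a config (the module name is illustrative, taken from the Mixtral example in the docstring above):

    from transformers import AwqConfig

    # "gate" is an illustrative entry matching e.g. Mixtral's gating layers;
    # adjust it to the modules of the model being quantized.
    quantization_config = AwqConfig(bits=4, modules_to_not_convert=["gate"])
    # If autoawq is missing or older than 0.1.8, the check added above rejects
    # this config with a ValueError.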
@@ -88,6 +88,7 @@ class AwqConfigTest(unittest.TestCase):
 class AwqTest(unittest.TestCase):
     model_name = "TheBloke/Mistral-7B-v0.1-AWQ"
     dummy_transformers_model_name = "bigscience/bloom-560m"
+    model_with_no_k_proj_quantized = "hf-internal-testing/opt-125m-awq-no-k-proj"
 
     input_text = "Hello my name is"
 
@@ -223,6 +224,24 @@ class AwqTest(unittest.TestCase):
 
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
+    def test_quantized_model_no_k_proj_quantized(self):
+        """
+        Simple test that checks if the quantized model works properly when some modules (here `k_proj`) were skipped at quantization time
+        """
+        dummy_input = torch.LongTensor([[0, 1, 0]]).to(torch_device)
+
+        quantized_model = AutoModelForCausalLM.from_pretrained(self.model_with_no_k_proj_quantized).to(torch_device)
+
+        self.assertTrue(isinstance(quantized_model.model.decoder.layers[0].self_attn.k_proj, torch.nn.Linear))
+        self.assertFalse(isinstance(quantized_model.model.decoder.layers[0].self_attn.v_proj, torch.nn.Linear))
+
+        EXPECTED_OUTPUT = torch.LongTensor([[0, 1, 0, 50118, 50118, 133, 248, 12, 134, 16, 10, 372, 2031]]).to(
+            torch_device
+        )
+
+        output = quantized_model.generate(dummy_input, max_new_tokens=10)
+        self.assertTrue((EXPECTED_OUTPUT == output).all())
+
     @slow
     @require_torch_gpu