Added cache_block_outputs option to enable GPTQ for non-regular models (#27032)

* Added cache_block_outputs option to enable GPTQ for non-regular models

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/utils/quantization_config.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Fixed style

* Update src/transformers/utils/quantization_config.py

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
This commit is contained in:
Alexander Kozlov 2023-11-01 18:37:19 +04:00 committed by GitHub
parent 037fb7d0e1
commit f9b4bea0a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -360,6 +360,8 @@ class GPTQConfig(QuantizationConfigMixin):
max_input_length (`int`, *optional*):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
length. It is specific to the exllama backend with act-order.
cache_block_outputs (`bool`, *optional*, defaults to `True`):
Whether to cache block outputs to reuse as inputs for the succeeding block.
"""
def __init__(
@ -380,6 +382,7 @@ class GPTQConfig(QuantizationConfigMixin):
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
max_input_length: Optional[int] = None,
cache_block_outputs: bool = True,
**kwargs,
):
self.quant_method = QuantizationMethod.GPTQ
@ -399,6 +402,7 @@ class GPTQConfig(QuantizationConfigMixin):
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.max_input_length = max_input_length
self.cache_block_outputs = cache_block_outputs
self.post_init()
def get_loading_attributes(self):