diff --git a/src/transformers/utils/quantization_config.py b/src/transformers/utils/quantization_config.py
index 51830aab574..35542d3f78f 100644
--- a/src/transformers/utils/quantization_config.py
+++ b/src/transformers/utils/quantization_config.py
@@ -360,6 +360,8 @@ class GPTQConfig(QuantizationConfigMixin):
         max_input_length (`int`, *optional*):
             The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
             length. It is specific to the exllama backend with act-order.
+        cache_block_outputs (`bool`, *optional*, defaults to `True`):
+            Whether to cache block outputs to reuse as inputs for the succeeding block.
     """
 
     def __init__(
@@ -380,6 +382,7 @@ class GPTQConfig(QuantizationConfigMixin):
         pad_token_id: Optional[int] = None,
         disable_exllama: bool = False,
         max_input_length: Optional[int] = None,
+        cache_block_outputs: bool = True,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.GPTQ
@@ -399,6 +402,7 @@ class GPTQConfig(QuantizationConfigMixin):
         self.pad_token_id = pad_token_id
         self.disable_exllama = disable_exllama
         self.max_input_length = max_input_length
+        self.cache_block_outputs = cache_block_outputs
         self.post_init()
 
     def get_loading_attributes(self):
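
A minimal usage sketch for the new flag, assuming the standard `GPTQConfig` quantization flow in `transformers`; the model id, bit width, and calibration dataset below are illustrative assumptions, not part of this diff.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

model_id = "facebook/opt-125m"  # hypothetical example model
tokenizer = AutoTokenizer.from_pretrained(model_id)

# Disable block-output caching during GPTQ quantization; the new
# cache_block_outputs argument defaults to True, preserving prior behavior.
gptq_config = GPTQConfig(
    bits=4,
    dataset="c4",
    tokenizer=tokenizer,
    cache_block_outputs=False,
)

# Quantize the model at load time using the config above.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=gptq_config)
```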