Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-03 03:31:05 +06:00)
enable/disable compile for quants methods (#36519)
* disable compile for most quants methods
* fix
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
* Update tests/quantization/bnb/test_mixed_int8.py
  Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
* Update src/transformers/generation/configuration_utils.py
  Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
* changes from joao suggestions

---------

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
This commit is contained in: parent c53d53da89, commit 9e94801146
@@ -379,8 +379,7 @@ class GenerationConfig(PushToHubMixin):
             If using a static cache, this controls how `generate` will `compile` the forward pass for performance
             gains.
-        disable_compile (`bool`, *optional*): Whether to disable the compilation of the forward pass when using 'statis' cache
-            implementation.
+        disable_compile (`bool`, *optional*): Whether to disable the automatic compilation of the forward pass. Automatic compilation happens when specific criteria are met, including using a compileable cache. Please open an issue if you find the need to use this flag.

         > Wild card
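For context, the flag can be passed straight through `generate`, since any `GenerationConfig` attribute can be overridden ad hoc. A minimal sketch, using the same tiny checkpoint as the tests further down purely for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative tiny checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("Hello my name is", return_tensors="pt")

# With a compileable (static) cache, `generate` would normally compile the forward pass;
# `disable_compile=True` opts out of that automatic compilation.
outputs = model.generate(
    **inputs,
    max_new_tokens=10,
    cache_implementation="static",
    disable_compile=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```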
@@ -1613,7 +1613,6 @@ class GenerationMixin:
             model_kwargs = generation_config.update(**kwargs)
         else:
             model_kwargs = kwargs

         return generation_config, model_kwargs

     def _get_initial_cache_position(self, input_ids, model_kwargs):
@@ -3281,7 +3280,9 @@ class GenerationMixin:
         model_forward = self.__call__
         if isinstance(model_kwargs.get("past_key_values"), Cache):
             is_compileable = model_kwargs["past_key_values"].is_compileable and self._supports_static_cache
-            is_compileable = is_compileable and not self.generation_config.disable_compile
+            if getattr(self, "hf_quantizer", None) is not None:
+                is_compileable &= self.hf_quantizer.is_compileable
+            is_compileable = is_compileable and not generation_config.disable_compile
             if is_compileable and (
                 self.device.type == "cuda" or generation_config.compile_config._compile_all_devices
             ):
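Condensed, the eligibility decision above boils down to the following standalone sketch (the helper name `_should_compile_forward` is hypothetical, not part of the library):

```python
def _should_compile_forward(model, generation_config, past_key_values) -> bool:
    """Illustrative restatement of the check in `GenerationMixin`.

    Compilation is attempted only when every gate agrees:
      * the cache itself is compileable (e.g. a static cache),
      * the model architecture supports static-cache compilation,
      * any attached quantizer declares itself compileable,
      * the user has not opted out via `disable_compile`.
    """
    is_compileable = past_key_values.is_compileable and model._supports_static_cache
    if getattr(model, "hf_quantizer", None) is not None:
        is_compileable &= model.hf_quantizer.is_compileable
    return is_compileable and not generation_config.disable_compile
```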
@@ -271,6 +271,11 @@ class HfQuantizer(ABC):
         """Flag indicating whether the quantized model can carry out quantization aware training"""
         return False

+    @property
+    def is_compileable(self) -> bool:
+        """Flag indicating whether the quantized model can be compiled"""
+        return False
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
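A quantizer that is safe to run under the compiled `generate` path can opt in by overriding the new property. A minimal sketch with a hypothetical `MyQuantizer`; the exact set of abstract methods may vary slightly between versions:

```python
from transformers.quantizers.base import HfQuantizer


class MyQuantizer(HfQuantizer):
    """Hypothetical quantizer shown only to illustrate the new `is_compileable` hook."""

    requires_calibration = False

    def _process_model_before_weight_loading(self, model, **kwargs):
        return model

    def _process_model_after_weight_loading(self, model, **kwargs):
        return model

    def is_serializable(self, safe_serialization=None):
        return True

    @property
    def is_trainable(self) -> bool:
        return False

    @property
    def is_compileable(self) -> bool:
        # Declare the quantized model safe for automatic compilation; the base class defaults to False.
        return True
```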
@@ -243,3 +243,7 @@ class TorchAoHfQuantizer(HfQuantizer):
             "int8_dynamic_activation_int8_weight",
         ]
         return self.quantization_config.quant_type in supported_quant_types_for_training
+
+    @property
+    def is_compileable(self) -> bool:
+        return True
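Since torchao declares itself compileable, a torchao-quantized model keeps the automatic compile path when a static cache is requested. A rough sketch, assuming `torchao` and `accelerate` are installed and a CUDA device is available; the checkpoint is illustrative:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative checkpoint
quant_config = TorchAoConfig("int8_weight_only")  # torchao weight-only int8 quantization

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model.device)

# Because TorchAoHfQuantizer.is_compileable is True, the static-cache forward pass can still
# be compiled automatically (on CUDA, unless disable_compile=True is passed).
outputs = model.generate(**inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```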
@@ -771,3 +771,36 @@ class Bnb4BitTestBasicConfigTest(unittest.TestCase):
         quantization_config = BitsAndBytesConfig(load_in_4bit=True)
         with self.assertRaisesRegex(ValueError, "load_in_4bit and load_in_8bit are both True"):
             quantization_config.load_in_8bit = True
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch_gpu_if_bnb_not_multi_backend_enabled
+@slow
+@apply_skip_if_not_implemented
+class Bnb4bitCompile(unittest.TestCase):
+    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+    input_text = "Hello my name is"
+
+    def setUp(self):
+        # Models and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model_4bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_4bit=True)
+
+    def test_generate_compile(self):
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+        # if nothing is set, compile will be disabled for bnb
+        self.model_4bit.generate(
+            input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
+            max_new_tokens=10,
+            cache_implementation="static",
+        )
+        with self.assertRaises(Exception):
+            # overwrite property
+            object.__setattr__(self.model_4bit.hf_quantizer, "is_compileable", True)
+            self.model_4bit.generate(
+                input_ids=encoded_input["input_ids"].to(self.model_4bit.device),
+                max_new_tokens=10,
+                cache_implementation="static",
+            )
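Outside the test harness, the behaviour this test relies on can be reproduced directly; a minimal sketch, assuming `bitsandbytes`, `accelerate`, and a supported accelerator are available:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model_4bit = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
    device_map="auto",
)

inputs = tokenizer("Hello my name is", return_tensors="pt").to(model_4bit.device)

# Because the bitsandbytes quantizer reports is_compileable=False, `generate` now silently
# skips the automatic torch.compile step for the static cache instead of failing at compile time.
outputs = model_4bit.generate(**inputs, max_new_tokens=10, cache_implementation="static")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```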
@@ -966,3 +966,37 @@ class MixedInt8LlamaTest(MixedInt8Test):
         output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(torch_device), max_new_tokens=10)

         self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
+
+@require_bitsandbytes
+@require_accelerate
+@require_torch
+@require_torch_gpu_if_bnb_not_multi_backend_enabled
+@slow
+@apply_skip_if_not_implemented
+class Bnb8bitCompile(unittest.TestCase):
+    model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM"
+    input_text = "Hello my name is"
+
+    def setUp(self):
+        # Models and tokenizer
+        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
+        self.model_8bit = AutoModelForCausalLM.from_pretrained(self.model_name, load_in_8bit=True)
+
+    def test_generate_compile(self):
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+        # if nothing is set, compile will be disabled for bnb
+        self.model_8bit.generate(
+            input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
+            max_new_tokens=10,
+            cache_implementation="static",
+        )
+
+        with self.assertRaises(Exception):
+            object.__setattr__(self.model_8bit.hf_quantizer, "is_compileable", True)
+            self.model_8bit.generate(
+                input_ids=encoded_input["input_ids"].to(self.model_8bit.device),
+                max_new_tokens=10,
+                cache_implementation="static",
+            )