mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
remove torch_dtype override (#25894)
* remove torch_dtype override

* style

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
parent
0f08cd205a
commit
ef10dbce5c
@@ -2578,11 +2578,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
|
||||
if quantization_method_from_config == QuantizationMethod.GPTQ:
|
||||
quantization_config = GPTQConfig.from_dict(config.quantization_config)
|
||||
config.quantization_config = quantization_config
|
||||
logger.info(
|
||||
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
|
||||
"requirements of `auto-gptq` to enable model quantization "
|
||||
)
|
||||
torch_dtype = torch.float16
|
||||
if torch_dtype is None:
|
||||
torch_dtype = torch.float16
|
||||
else:
|
||||
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
|
||||
|
||||
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
|
||||
|
||||
if (
|
||||
|
Loading…
Reference in New Issue
Block a user