remove torch_dtype override (#25894)

* remove torch_dtype override

* style

* Update src/transformers/modeling_utils.py

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>

---------

Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
This commit is contained in:
Marc Sun 2023-08-31 17:38:14 -04:00 committed by GitHub
parent 0f08cd205a
commit ef10dbce5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2578,11 +2578,11 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if quantization_method_from_config == QuantizationMethod.GPTQ:
quantization_config = GPTQConfig.from_dict(config.quantization_config)
config.quantization_config = quantization_config
logger.info(
f"Overriding torch_dtype={torch_dtype} with `torch_dtype=torch.float16` due to "
"requirements of `auto-gptq` to enable model quantization "
)
torch_dtype = torch.float16
if torch_dtype is None:
torch_dtype = torch.float16
else:
logger.info("We suggest you to set `torch_dtype=torch.float16` for better efficiency with GPTQ.")
quantizer = GPTQQuantizer.from_dict(quantization_config.to_dict())
if (