Mirror of https://github.com/huggingface/transformers.git
fix torch_dtype on awq (#38463)
Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
parent 871901cb3d
commit 2e889c18e1
@@ -92,11 +92,11 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
-        elif torch_dtype == torch.bfloat16:
-            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
+            logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
             torch_dtype = torch.float16
-        elif torch_dtype != torch.float16:
-            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+        elif torch_dtype != torch.float16 and torch.cuda.is_available():
+            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
         return torch_dtype

     def _process_model_before_weight_loading(
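The practical effect: the forced cast to `torch.float16` is now scoped to CUDA, since it is the AWQ CUDA kernels that lack `bfloat16` support; on non-CUDA backends the user-requested dtype is preserved. A minimal usage sketch of the new behavior (the checkpoint ID is an assumed example, any AWQ-quantized checkpoint on the Hub would do):

    import torch
    from transformers import AutoModelForCausalLM

    # After this fix, a requested bfloat16 dtype is only coerced back to
    # float16 when torch.cuda.is_available() is True; on supported
    # non-CUDA backends the requested dtype is kept as-is.
    model = AutoModelForCausalLM.from_pretrained(
        "TheBloke/Mistral-7B-v0.1-AWQ",  # assumed example AWQ checkpoint
        torch_dtype=torch.bfloat16,
    )
    # Expected: float16 on CUDA machines (with a warning in the logs);
    # bfloat16 where the non-CUDA AWQ path supports it.
    print(model.dtype)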