fix torch_dtype on awq (#38463)

Signed-off-by: jiqing-feng <jiqing.feng@intel.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Author: jiqing-feng <jiqing.feng@intel.com>
Date: 2025-06-06 23:14:00 +08:00 (committed by GitHub)
Commit: 2e889c18e1
Parent: 871901cb3d


@@ -92,11 +92,11 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
-        elif torch_dtype == torch.bfloat16:
-            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
+            logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
             torch_dtype = torch.float16
-        elif torch_dtype != torch.float16:
-            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+        elif torch_dtype != torch.float16 and torch.cuda.is_available():
+            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
         return torch_dtype

     def _process_model_before_weight_loading(
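
For context, the practical effect of this change is that the bfloat16 warning and cast now only apply when CUDA is available, so non-CUDA backends (e.g. CPU or Intel XPU) can load AWQ checkpoints in `torch.bfloat16` without being force-cast to `torch.float16`. A minimal usage sketch of how a caller hits this code path; the checkpoint name is only an illustrative AWQ model from the Hub, and any AWQ-quantized checkpoint behaves the same way:

import torch
from transformers import AutoModelForCausalLM

# Illustrative AWQ-quantized checkpoint; substitute any AWQ model.
model_id = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

# On a CUDA machine, AwqQuantizer.update_torch_dtype() warns and casts
# bfloat16 -> float16, since the AWQ CUDA kernels only support float16.
# On a CPU/XPU-only machine, after this fix, bfloat16 is kept as requested.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
)
print(model.dtype)  # torch.float16 with CUDA, torch.bfloat16 otherwise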