From 2e889c18e16a4f8d91d1b9f92110522fa16e8c97 Mon Sep 17 00:00:00 2001
From: jiqing-feng
Date: Fri, 6 Jun 2025 23:14:00 +0800
Subject: [PATCH] fix torch_dtype on awq (#38463)

Signed-off-by: jiqing-feng
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
---
 src/transformers/quantizers/quantizer_awq.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 49837d53128..ee6c6360202 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -92,11 +92,11 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
-        elif torch_dtype == torch.bfloat16:
-            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
+            logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
             torch_dtype = torch.float16
-        elif torch_dtype != torch.float16:
-            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+        elif torch_dtype != torch.float16 and torch.cuda.is_available():
+            logger.warning("We suggest setting `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
         return torch_dtype

     def _process_model_before_weight_loading(
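
Below is a minimal standalone sketch (not part of the patch) of the dtype-resolution behavior this change produces. The helper name resolve_awq_dtype is hypothetical; the real logic lives in AwqQuantizer.update_torch_dtype in src/transformers/quantizers/quantizer_awq.py. It illustrates that after this fix the bfloat16-to-float16 cast only happens when CUDA is available, so CPU and other non-CUDA backends keep the requested dtype:

import torch

def resolve_awq_dtype(torch_dtype):
    # Hypothetical standalone version of the patched dtype resolution.
    if torch_dtype is None:
        # No dtype requested: default to float16, as the quantizer does.
        return torch.float16
    if torch_dtype == torch.bfloat16 and torch.cuda.is_available():
        # Only the AWQ CUDA kernels lack bfloat16 support, so the cast
        # (and its warning) is now gated on CUDA being present.
        return torch.float16
    # On CPU or other non-CUDA backends, bfloat16 and any other dtype
    # pass through unchanged.
    return torch_dtype

print(resolve_awq_dtype(None))            # torch.float16
print(resolve_awq_dtype(torch.bfloat16))  # torch.float16 on CUDA, torch.bfloat16 otherwise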