diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 49837d53128..ee6c6360202 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -92,11 +92,11 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
-        elif torch_dtype == torch.bfloat16:
-            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+        elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
+            logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
             torch_dtype = torch.float16
-        elif torch_dtype != torch.float16:
-            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
+        elif torch_dtype != torch.float16 and torch.cuda.is_available():
+            logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
         return torch_dtype

     def _process_model_before_weight_loading(
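The patch scopes the bfloat16-to-float16 downcast (and the float16 suggestion) to CUDA only, so non-CUDA backends can keep a user-requested dtype. For context, here is a minimal standalone sketch of the revised dtype-selection logic; the function name `resolve_awq_torch_dtype` and the module-level logger are illustrative, not part of the transformers API:

```python
import logging

import torch

logger = logging.getLogger(__name__)


def resolve_awq_torch_dtype(torch_dtype=None):
    # Hypothetical standalone mirror of AwqQuantizer.update_torch_dtype
    # as it behaves after this patch.
    if torch_dtype is None:
        # Default to float16, which the AWQ kernels support everywhere.
        torch_dtype = torch.float16
        logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
    elif torch_dtype == torch.bfloat16 and torch.cuda.is_available():
        # Force-cast only on CUDA: the CUDA AWQ kernels lack bfloat16
        # support, while other backends (e.g. CPU) can keep bfloat16.
        logger.warning("`torch.bfloat16` is not supported for AWQ CUDA kernels yet. Casting to `torch.float16`.")
        torch_dtype = torch.float16
    elif torch_dtype != torch.float16 and torch.cuda.is_available():
        # On CUDA, float16 remains the recommended dtype for AWQ efficiency.
        logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency on CUDA with AWQ.")
    return torch_dtype
```

In effect, a user who passes `torch_dtype=torch.bfloat16` on a machine without CUDA now keeps bfloat16 instead of being silently downcast to float16, and the efficiency warning no longer fires for non-CUDA backends where the advice does not apply.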