diff --git a/src/transformers/quantizers/base.py b/src/transformers/quantizers/base.py
index 08cd1e2e7d3..317bce07592 100644
--- a/src/transformers/quantizers/base.py
+++ b/src/transformers/quantizers/base.py
@@ -263,14 +263,17 @@ class HfQuantizer(ABC):
         model: "PreTrainedModel",
         skip_modules: Optional[List[str]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
+        add_default_skips: bool = False,
     ):
         from ..integrations import get_keys_to_not_convert
 
-        modules_to_not_convert = []
-        if skip_modules is None:
+        if skip_modules is None or add_default_skips:
             modules_to_not_convert = get_keys_to_not_convert(model)
         else:
-            modules_to_not_convert = skip_modules
+            modules_to_not_convert = []
+
+        if skip_modules is not None:
+            modules_to_not_convert.extend(skip_modules)
 
         if keep_in_fp32_modules is not None:
             modules_to_not_convert.extend(keep_in_fp32_modules)
diff --git a/src/transformers/quantizers/quantizer_awq.py b/src/transformers/quantizers/quantizer_awq.py
index 28460ac38ed..8e63e2f5bf6 100644
--- a/src/transformers/quantizers/quantizer_awq.py
+++ b/src/transformers/quantizers/quantizer_awq.py
@@ -92,6 +92,9 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
+        elif torch_dtype == torch.bfloat16:
+            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+            torch_dtype = torch.float16
         elif torch_dtype != torch.float16:
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
@@ -102,7 +105,7 @@ class AwqQuantizer(HfQuantizer):
         from ..integrations import replace_quantization_scales, replace_with_awq_linear
 
         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules, add_default_skips=True
        )
 
         model, has_been_replaced = replace_with_awq_linear(
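
For illustration, here is a minimal standalone sketch of the merge behavior the `base.py` hunk introduces. The function name `merge_modules_to_not_convert` and the `default_skips` argument are hypothetical stand-ins (the real code derives the defaults from `get_keys_to_not_convert(model)`); this is not the library implementation, only a mirror of the new control flow.

```python
from typing import List, Optional


def merge_modules_to_not_convert(
    default_skips: List[str],
    skip_modules: Optional[List[str]] = None,
    keep_in_fp32_modules: Optional[List[str]] = None,
    add_default_skips: bool = False,
) -> List[str]:
    # Defaults are used when no explicit skips are given, or when the caller
    # opts in via `add_default_skips` (as AwqQuantizer now does).
    if skip_modules is None or add_default_skips:
        modules_to_not_convert = list(default_skips)
    else:
        modules_to_not_convert = []

    # Explicit skips are now appended instead of replacing the defaults.
    if skip_modules is not None:
        modules_to_not_convert.extend(skip_modules)

    if keep_in_fp32_modules is not None:
        modules_to_not_convert.extend(keep_in_fp32_modules)

    return modules_to_not_convert


# With add_default_skips=True, user-provided skip modules are added on top of
# the model's default untouched modules rather than overriding them.
print(merge_modules_to_not_convert(["lm_head"], skip_modules=["gate"], add_default_skips=True))
# -> ['lm_head', 'gate']
```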