Fix some failing AWQ tests (#37383)

* update AwqQuantizer

* fix style

* add an add_default_skips arg to get_modules_to_not_convert so it can also include the defaults from get_keys_to_not_convert(model)
DerekLiu35 2025-04-09 12:24:57 -04:00 committed by GitHub
parent 71b35387fd
commit c5c648dd74
2 changed files with 10 additions and 4 deletions

@@ -263,14 +263,17 @@ class HfQuantizer(ABC):
         model: "PreTrainedModel",
         skip_modules: Optional[List[str]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
+        add_default_skips: bool = False,
     ):
         from ..integrations import get_keys_to_not_convert
 
-        modules_to_not_convert = []
-        if skip_modules is None:
+        if skip_modules is None or add_default_skips:
             modules_to_not_convert = get_keys_to_not_convert(model)
         else:
-            modules_to_not_convert = skip_modules
+            modules_to_not_convert = []
+
+        if skip_modules is not None:
+            modules_to_not_convert.extend(skip_modules)
 
         if keep_in_fp32_modules is not None:
             modules_to_not_convert.extend(keep_in_fp32_modules)
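
For reference, a minimal standalone sketch of the merge semantics this hunk introduces. The function name resolve_modules_to_not_convert and the default_skips parameter are hypothetical stand-ins for the quantizer method and for get_keys_to_not_convert(model); this is an illustration, not the library code:

from typing import List, Optional


def resolve_modules_to_not_convert(
    default_skips: List[str],
    skip_modules: Optional[List[str]] = None,
    keep_in_fp32_modules: Optional[List[str]] = None,
    add_default_skips: bool = False,
) -> List[str]:
    # default_skips stands in for get_keys_to_not_convert(model).
    # With add_default_skips=True, a configured skip_modules list no longer
    # replaces the defaults; the two lists are merged.
    if skip_modules is None or add_default_skips:
        modules_to_not_convert = list(default_skips)
    else:
        modules_to_not_convert = []

    if skip_modules is not None:
        modules_to_not_convert.extend(skip_modules)

    if keep_in_fp32_modules is not None:
        modules_to_not_convert.extend(keep_in_fp32_modules)

    return modules_to_not_convert


# Previously, passing skip_modules meant the defaults were dropped entirely;
# with add_default_skips=True both lists survive.
print(resolve_modules_to_not_convert(["lm_head"], skip_modules=["visual"], add_default_skips=True))
# ['lm_head', 'visual']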


@@ -92,6 +92,9 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
+        elif torch_dtype == torch.bfloat16:
+            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+            torch_dtype = torch.float16
         elif torch_dtype != torch.float16:
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
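
A hedged sketch of the dtype handling after this hunk, as plain PyTorch. resolve_awq_dtype is a hypothetical stand-in for the quantizer's dtype hook; the log strings are taken from the diff above:

import logging
from typing import Optional

import torch

logger = logging.getLogger(__name__)


def resolve_awq_dtype(torch_dtype: Optional[torch.dtype]) -> torch.dtype:
    # Mirrors the branches above: bfloat16 is now caught and downgraded to
    # float16 with a warning, since the AWQ kernels only support fp16.
    if torch_dtype is None:
        torch_dtype = torch.float16
        logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
    elif torch_dtype == torch.bfloat16:
        logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
        torch_dtype = torch.float16
    elif torch_dtype != torch.float16:
        logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
    return torch_dtype


assert resolve_awq_dtype(torch.bfloat16) == torch.float16
assert resolve_awq_dtype(None) == torch.float16
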
@@ -102,7 +102,7 @@ class AwqQuantizer(HfQuantizer):
         from ..integrations import replace_quantization_scales, replace_with_awq_linear
 
         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules, add_default_skips=True
         )
 
         model, has_been_replaced = replace_with_awq_linear(
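
Finally, a usage sketch of the net effect when loading an AWQ checkpoint: the quantizer now calls get_modules_to_not_convert with add_default_skips=True, so the defaults from get_keys_to_not_convert(model) are kept alongside any modules_to_not_convert coming from the checkpoint's quantization config. The model id below is only an assumed example; this snippet is not part of the patch:

import torch
from transformers import AutoModelForCausalLM

# Assumed example AWQ checkpoint; any AWQ-quantized model on the Hub works,
# and autoawq/accelerate need to be installed.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.2-AWQ",
    torch_dtype=torch.float16,
    device_map="auto",
)

# The resolved skip list should now contain both the config-provided entries
# and the default keys (e.g. the output embedding / lm_head).
print(model.hf_quantizer.modules_to_not_convert)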