Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
Fix some failing AWQ tests (#37383)
* update AwqQuantizer
* fix style
* add an arg to get_modules_to_not_convert so callers can also include get_keys_to_not_convert(model)
parent 71b35387fd
commit c5c648dd74
src/transformers/quantizers/base.py

@@ -263,14 +263,17 @@ class HfQuantizer(ABC):
         model: "PreTrainedModel",
         skip_modules: Optional[List[str]] = None,
         keep_in_fp32_modules: Optional[List[str]] = None,
+        add_default_skips: bool = False,
     ):
         from ..integrations import get_keys_to_not_convert

-        if skip_modules is None:
+        modules_to_not_convert = []
+        if skip_modules is None or add_default_skips:
             modules_to_not_convert = get_keys_to_not_convert(model)
-        else:
-            modules_to_not_convert = skip_modules
+
+        if skip_modules is not None:
+            modules_to_not_convert.extend(skip_modules)

         if keep_in_fp32_modules is not None:
             modules_to_not_convert.extend(keep_in_fp32_modules)
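The effect of the new add_default_skips flag: previously an explicit skip_modules list replaced the defaults computed by get_keys_to_not_convert(model); now the two lists can be merged. A minimal standalone sketch of the new semantics, assuming a hypothetical stub for get_keys_to_not_convert (the real helper inspects the model):

from typing import List, Optional

def get_keys_to_not_convert(model) -> List[str]:
    # Hypothetical stub: the real helper walks the model to find modules
    # that should stay unquantized (e.g. tied/output embeddings).
    return ["lm_head"]

def get_modules_to_not_convert(
    model,
    skip_modules: Optional[List[str]] = None,
    keep_in_fp32_modules: Optional[List[str]] = None,
    add_default_skips: bool = False,
) -> List[str]:
    modules_to_not_convert = []
    # Defaults are collected when no explicit skips are given
    # OR when the caller opts in via add_default_skips.
    if skip_modules is None or add_default_skips:
        modules_to_not_convert = get_keys_to_not_convert(model)
    # Explicit skips are merged in instead of replacing the defaults.
    if skip_modules is not None:
        modules_to_not_convert.extend(skip_modules)
    if keep_in_fp32_modules is not None:
        modules_to_not_convert.extend(keep_in_fp32_modules)
    return modules_to_not_convert

# The old behavior would have returned only ["model.norm"]; the merged
# result keeps the defaults as well:
assert get_modules_to_not_convert(None, ["model.norm"], add_default_skips=True) == ["lm_head", "model.norm"]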
src/transformers/quantizers/quantizer_awq.py

@@ -92,6 +92,9 @@ class AwqQuantizer(HfQuantizer):
         if torch_dtype is None:
             torch_dtype = torch.float16
             logger.info("Loading the model in `torch.float16`. To overwrite it, set `torch_dtype` manually.")
+        elif torch_dtype == torch.bfloat16:
+            logger.warning("`torch.bfloat16` is not supported for AWQ kernels yet. Casting to `torch.float16`.")
+            torch_dtype = torch.float16
         elif torch_dtype != torch.float16:
             logger.warning("We suggest you to set `torch_dtype=torch.float16` for better efficiency with AWQ.")
         return torch_dtype
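The added elif branch covers the case where a caller requests torch.bfloat16, which the AWQ kernels do not support: instead of passing it through, the quantizer now logs a warning and falls back to float16. A minimal usage sketch; the checkpoint name is only an example of an AWQ-quantized model:

import torch
from transformers import AutoModelForCausalLM

# bfloat16 is now coerced to float16 by the AWQ quantizer's dtype check
# rather than reaching the float16-only AWQ kernels.
model = AutoModelForCausalLM.from_pretrained(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ",  # example AWQ checkpoint
    torch_dtype=torch.bfloat16,  # warning is logged, then cast to torch.float16
    device_map="auto",
)
print(model.dtype)  # expected: torch.float16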
@@ -102,7 +105,7 @@ class AwqQuantizer(HfQuantizer):
         from ..integrations import replace_quantization_scales, replace_with_awq_linear

         self.modules_to_not_convert = self.get_modules_to_not_convert(
-            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
+            model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules, add_default_skips=True
         )

         model, has_been_replaced = replace_with_awq_linear(
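With add_default_skips=True, the AWQ quantizer now combines the user's modules_to_not_convert from the quantization config with the model's default skip list, rather than letting one shadow the other. A small sketch of how this surfaces through AwqConfig; the module name is illustrative:

from transformers import AwqConfig

# User-specified skips from the quantization config...
quant_config = AwqConfig(bits=4, modules_to_not_convert=["model.layers.0.mlp"])

# ...are merged with the defaults during loading, roughly:
#   self.modules_to_not_convert = self.get_modules_to_not_convert(
#       model, quant_config.modules_to_not_convert, keep_in_fp32_modules,
#       add_default_skips=True,
#   )
# so default entries such as the lm_head stay unquantized alongside
# "model.layers.0.mlp".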