Fix AWQ tests failing due to the IPEX backend (#34011)

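Fix AWQ tests: import WQLinear_IPEX only when IPEX is available and the installed autoawq is recent enough, and raise an explicit error when the IPEX backend is requested with an older autoawq.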
This commit is contained in:
Marc Sun 2024-10-08 15:56:05 +02:00 committed by GitHub
parent 4f2bf135af
commit 1909def2de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 16 additions and 5 deletions

View File

@ -13,9 +13,13 @@
# limitations under the License.
"AWQ (Activation aware Weight Quantization) integration file"
import importlib
from packaging import version
from ..activations import ACT2FN
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_awq_available, is_torch_available, logging
from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
from ..utils.quantization_config import (
AwqBackendPackingMethod,
AwqConfig,
@ -379,7 +383,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
The `QuantAttentionFused` class as it only supports that class
for now.
"""
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_IPEX
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
module_has_been_fused = False
@ -396,9 +400,12 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
elif isinstance(q_proj, WQLinear_GEMM):
linear_target_cls = WQLinear_GEMM
cat_dim = 1
elif isinstance(q_proj, WQLinear_IPEX):
linear_target_cls = WQLinear_IPEX
cat_dim = 1
elif is_ipex_available() and version.parse(importlib.metadata.version("autoawq")) > version.parse("0.2.6"):
from awq.modules.linear import WQLinear_IPEX
if isinstance(q_proj, WQLinear_IPEX):
linear_target_cls = WQLinear_IPEX
cat_dim = 1
else:
raise ValueError("Unsupported q_proj type: {type(q_proj)}")

View File

@ -53,6 +53,10 @@ class AwqQuantizer(HfQuantizer):
raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
if self.quantization_config.version == AWQLinearVersion.IPEX:
if version.parse(importlib.metadata.version("autoawq")) < version.parse("0.2.6"):
raise RuntimeError(
"To use IPEX backend, you need autoawq>0.6.2. Please install the latest version or from source."
)
if (
device_map is not None
and isinstance(device_map, dict)