mirror of
https://github.com/huggingface/transformers.git
synced 2025-08-02 19:21:31 +06:00
parent
4f2bf135af
commit
1909def2de
@ -13,9 +13,13 @@
|
||||
# limitations under the License.
|
||||
"AWQ (Activation aware Weight Quantization) integration file"
|
||||
|
||||
import importlib
|
||||
|
||||
from packaging import version
|
||||
|
||||
from ..activations import ACT2FN
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from ..utils import is_auto_awq_available, is_torch_available, logging
|
||||
from ..utils import is_auto_awq_available, is_ipex_available, is_torch_available, logging
|
||||
from ..utils.quantization_config import (
|
||||
AwqBackendPackingMethod,
|
||||
AwqConfig,
|
||||
@ -379,7 +383,7 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
|
||||
The `QuantAttentionFused` class as it only supports that class
|
||||
for now.
|
||||
"""
|
||||
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_IPEX
|
||||
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
|
||||
|
||||
module_has_been_fused = False
|
||||
|
||||
@ -396,9 +400,12 @@ def _fuse_awq_attention_layers(model, module, modules_to_fuse, current_module_na
|
||||
elif isinstance(q_proj, WQLinear_GEMM):
|
||||
linear_target_cls = WQLinear_GEMM
|
||||
cat_dim = 1
|
||||
elif isinstance(q_proj, WQLinear_IPEX):
|
||||
linear_target_cls = WQLinear_IPEX
|
||||
cat_dim = 1
|
||||
elif is_ipex_available() and version.parse(importlib.metadata.version("autoawq")) > version.parse("0.2.6"):
|
||||
from awq.modules.linear import WQLinear_IPEX
|
||||
|
||||
if isinstance(q_proj, WQLinear_IPEX):
|
||||
linear_target_cls = WQLinear_IPEX
|
||||
cat_dim = 1
|
||||
else:
|
||||
raise ValueError("Unsupported q_proj type: {type(q_proj)}")
|
||||
|
||||
|
@ -53,6 +53,10 @@ class AwqQuantizer(HfQuantizer):
|
||||
raise ImportError("Loading an AWQ quantized model requires accelerate (`pip install accelerate`)")
|
||||
|
||||
if self.quantization_config.version == AWQLinearVersion.IPEX:
|
||||
if version.parse(importlib.metadata.version("autoawq")) < version.parse("0.2.6"):
|
||||
raise RuntimeError(
|
||||
"To use IPEX backend, you need autoawq>0.6.2. Please install the latest version or from source."
|
||||
)
|
||||
if (
|
||||
device_map is not None
|
||||
and isinstance(device_map, dict)
|
||||
|
Loading…
Reference in New Issue
Block a user