mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-03 12:50:06 +06:00
fix bug when using gptq model on xpu device
Signed-off-by: Liu, Kaixuan <kaixuan.liu@intel.com>
This commit is contained in:
parent
dbc98328da
commit
dcb694c6c3
@ -22,7 +22,14 @@ from .base import HfQuantizer
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..modeling_utils import PreTrainedModel
|
from ..modeling_utils import PreTrainedModel
|
||||||
|
|
||||||
from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging
|
from ..utils import (
|
||||||
|
is_auto_gptq_available,
|
||||||
|
is_gptqmodel_available,
|
||||||
|
is_optimum_available,
|
||||||
|
is_torch_available,
|
||||||
|
is_torch_xpu_available,
|
||||||
|
logging,
|
||||||
|
)
|
||||||
from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin
|
from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin
|
||||||
|
|
||||||
|
|
||||||
@ -89,7 +96,12 @@ class GptqHfQuantizer(HfQuantizer):
|
|||||||
|
|
||||||
def update_device_map(self, device_map):
|
def update_device_map(self, device_map):
|
||||||
if device_map is None:
|
if device_map is None:
|
||||||
device_map = {"": torch.device("cpu")}
|
if torch.cuda.is_available():
|
||||||
|
device_map = {"": torch.cuda.current_device()}
|
||||||
|
elif is_torch_xpu_available():
|
||||||
|
device_map = {"": torch.xpu.current_device()}
|
||||||
|
else:
|
||||||
|
device_map = {"": "cpu"}
|
||||||
# When only auto-gptq is installed, CPU is not supported, so we should move the model to CUDA if available.
|
# When only auto-gptq is installed, CPU is not supported, so we should move the model to CUDA if available.
|
||||||
if not is_gptqmodel_available() and device_map in ("cpu", {"": torch.device("cpu")}):
|
if not is_gptqmodel_available() and device_map in ("cpu", {"": torch.device("cpu")}):
|
||||||
device_map == {"": 0}
|
device_map == {"": 0}
|
||||||
|
Loading…
Reference in New Issue
Block a user