diff --git a/src/transformers/quantizers/quantizer_gptq.py b/src/transformers/quantizers/quantizer_gptq.py
index 0fc35bea163..0aacfd55f36 100644
--- a/src/transformers/quantizers/quantizer_gptq.py
+++ b/src/transformers/quantizers/quantizer_gptq.py
@@ -22,7 +22,14 @@ from .base import HfQuantizer
 
 if TYPE_CHECKING:
     from ..modeling_utils import PreTrainedModel
 
-from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging
+from ..utils import (
+    is_auto_gptq_available,
+    is_gptqmodel_available,
+    is_optimum_available,
+    is_torch_available,
+    is_torch_xpu_available,
+    logging,
+)
 from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin
 
@@ -89,7 +96,12 @@ class GptqHfQuantizer(HfQuantizer):
 
     def update_device_map(self, device_map):
        if device_map is None:
-            device_map = {"": torch.device("cpu")}
+            if torch.cuda.is_available():
+                device_map = {"": torch.cuda.current_device()}
+            elif is_torch_xpu_available():
+                device_map = {"": torch.xpu.current_device()}
+            else:
+                device_map = {"": "cpu"}
         # Only with auto-gptq do not support CPU, we should move the model to cuda if available.
         if not is_gptqmodel_available() and device_map in ("cpu", {"": torch.device("cpu")}):
             device_map == {"": 0}
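
For readers skimming the diff: the change makes `update_device_map` prefer an available accelerator over the previous unconditional CPU default, trying CUDA first, then Intel XPU, then falling back to CPU. Below is a minimal standalone sketch of that fallback order; `is_torch_xpu_available` is the real helper imported in the diff above, while the function name `resolve_default_device_map` is illustrative only and not part of the library.

```python
# Minimal sketch of the new default device-map resolution, assuming torch
# and transformers are installed. `resolve_default_device_map` is a
# hypothetical name used for illustration; in the library this logic lives
# inside GptqHfQuantizer.update_device_map.
import torch

from transformers.utils import is_torch_xpu_available


def resolve_default_device_map(device_map=None):
    # Prefer CUDA, then Intel XPU, and only fall back to CPU when neither
    # accelerator is present (mirrors the diff above).
    if device_map is None:
        if torch.cuda.is_available():
            device_map = {"": torch.cuda.current_device()}
        elif is_torch_xpu_available():
            device_map = {"": torch.xpu.current_device()}
        else:
            device_map = {"": "cpu"}
    return device_map


print(resolve_default_device_map())  # e.g. {'': 0} on a CUDA/XPU machine, {'': 'cpu'} otherwise
```

In practice this means a GPTQ checkpoint loaded via `from_pretrained` without an explicit `device_map` now lands on the available accelerator (CUDA or XPU) instead of always defaulting to CPU.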