This commit is contained in:
kaixuanliu 2025-07-02 23:32:46 +02:00 committed by GitHub
commit d04a906159
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -22,7 +22,14 @@ from .base import HfQuantizer
if TYPE_CHECKING:
from ..modeling_utils import PreTrainedModel
from ..utils import is_auto_gptq_available, is_gptqmodel_available, is_optimum_available, is_torch_available, logging
from ..utils import (
is_auto_gptq_available,
is_gptqmodel_available,
is_optimum_available,
is_torch_available,
is_torch_xpu_available,
logging,
)
from ..utils.quantization_config import GPTQConfig, QuantizationConfigMixin
@ -89,7 +96,12 @@ class GptqHfQuantizer(HfQuantizer):
def update_device_map(self, device_map):
if device_map is None:
device_map = {"": torch.device("cpu")}
if torch.cuda.is_available():
device_map = {"": torch.cuda.current_device()}
elif is_torch_xpu_available():
device_map = {"": torch.xpu.current_device()}
else:
device_map = {"": "cpu"}
# Only auto-gptq lacks CPU support, so the model should be moved to CUDA if it is available.
if not is_gptqmodel_available() and device_map in ("cpu", {"": torch.device("cpu")}):
device_map == {"": 0}