Gaudi: Fix the pipeline failure with the HPU device (#36990)

* Gaudi: fix the issue of is_torch_hpu_available() returning false

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Fix make fixup

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Add comments for the implicit behavior of import

Signed-off-by: yuanwu <yuan.wu@intel.com>

* Update src/transformers/utils/import_utils.py

* Update src/transformers/utils/import_utils.py

---------

Signed-off-by: yuanwu <yuan.wu@intel.com>
Co-authored-by: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com>
This commit is contained in:
Yuan Wu 2025-03-31 16:23:47 +08:00 committed by GitHub
parent 6acd5aecb3
commit bd41b9c1ac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 12 additions and 2 deletions

View File

@@ -947,12 +947,18 @@ class Pipeline(_ScikitCompat, PushToHubMixin):
if device == -1 and self.model.device is not None:
device = self.model.device
if isinstance(device, torch.device):
if device.type == "xpu" and not is_torch_xpu_available(check_device=True):
if (device.type == "xpu" and not is_torch_xpu_available(check_device=True)) or (
device.type == "hpu" and not is_torch_hpu_available()
):
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
self.device = device
elif isinstance(device, str):
if "xpu" in device and not is_torch_xpu_available(check_device=True):
if ("xpu" in device and not is_torch_xpu_available(check_device=True)) or (
"hpu" in device and not is_torch_hpu_available()
):
raise ValueError(f'{device} is not available, you should use device="cpu" instead')
self.device = torch.device(device)
elif device < 0:
self.device = torch.device("cpu")

View File

@@ -811,6 +811,10 @@ def is_torch_hpu_available():
import torch
if os.environ.get("PT_HPU_LAZY_MODE", "1") == "1":
# import habana_frameworks.torch in case of lazy mode to patch torch with torch.hpu
import habana_frameworks.torch # noqa: F401
if not hasattr(torch, "hpu") or not torch.hpu.is_available():
return False