Properly check for a TPU device (#17802)

This commit is contained in:
Zachary Mueller 2022-06-21 13:39:55 -04:00 committed by GitHub
parent ef23fae596
commit 52404cbad4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -399,12 +399,16 @@ def is_ftfy_available():
def is_torch_tpu_available():
if not _torch_available:
return False
# This test is probably enough, but just in case, we unpack a bit.
if importlib.util.find_spec("torch_xla") is None:
return False
if importlib.util.find_spec("torch_xla.core") is None:
import torch_xla.core.xla_model as xm
# We need to check if `xla_device` can be found, will raise a RuntimeError if not
try:
xm.xla_device()
return True
except RuntimeError:
return False
return importlib.util.find_spec("torch_xla.core.xla_model") is not None
def is_torchdynamo_available():