Should check that torch TPU is available (#5636)

This commit is contained in:
Lysandre Debut 2020-07-09 13:54:32 -04:00 committed by GitHub
parent 3cc23eee06
commit b25f7802de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -34,6 +34,7 @@ from .file_utils import (
cached_path,
hf_bucket_url,
is_remote_url,
is_torch_tpu_available,
)
from .generation_utils import GenerationMixin
@ -794,7 +795,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
}
return model, loading_info
if hasattr(config, "xla_device") and config.xla_device:
if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
import torch_xla.core.xla_model as xm
model = xm.send_cpu_data_to_device(model, xm.xla_device())