mirror of
https://github.com/huggingface/transformers.git
synced 2025-07-31 02:02:21 +06:00
Should check that torch TPU is available (#5636)
This commit is contained in:
parent
3cc23eee06
commit
b25f7802de
@@ -34,6 +34,7 @@ from .file_utils import (
     cached_path,
     hf_bucket_url,
     is_remote_url,
+    is_torch_tpu_available,
 )
 from .generation_utils import GenerationMixin
@@ -794,7 +795,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):
             }
             return model, loading_info

-        if hasattr(config, "xla_device") and config.xla_device:
+        if hasattr(config, "xla_device") and config.xla_device and is_torch_tpu_available():
             import torch_xla.core.xla_model as xm

             model = xm.send_cpu_data_to_device(model, xm.xla_device())
Loading…
Reference in New Issue
Block a user