diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index bbb1f22350e..892f2e3b5bc 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1692,7 +1692,7 @@ class TrainingArguments: # cpu raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10") elif not self.use_cpu: - if not is_torch_bf16_gpu_available(): + if not is_torch_bf16_gpu_available() and not is_torch_xla_available(): # added for tpu support error_message = "Your setup doesn't support bf16/gpu." if is_torch_cuda_available(): error_message += " You need Ampere+ GPU with cuda>=11.0"