diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 5a57fbfef5f..27652435f44 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -565,7 +565,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")

         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 4cb0523b404..b92afac1712 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1122,9 +1122,9 @@ class TrainingArguments:

         if self.bf16 or self.bf16_full_eval:
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
-                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+                raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                 # gpu
                 raise ValueError(
@@ -1172,12 +1172,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
+                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
             )

         if self.torchdynamo is not None: