Enable bf16 option for XLA devices (#20684)

commit bcc069ddb8 (parent 9858ecd706)
Authored by jeffhataws on 2022-12-08 09:34:40 -08:00, committed by GitHub
2 changed files with 5 additions and 4 deletions

@@ -565,7 +565,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")
         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
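With this change, an available TPU/NeuronCore (XLA) device is treated like DeepSpeed and SageMaker Model Parallel: the Trainer skips its own CUDA/CPU AMP and grad-scaling setup and leaves bf16 handling to the XLA backend. A minimal usage sketch (the output directory is a placeholder, not part of this commit):

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="out",  # hypothetical path
    bf16=True,         # after this commit, accepted on XLA devices (TPU/NeuronCore) as well as CUDA/CPU
)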

@@ -1122,9 +1122,9 @@ class TrainingArguments:
         if self.bf16 or self.bf16_full_eval:
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
-                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+                raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                 # gpu
                 raise ValueError(
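The relaxed check above means bf16 with no_cuda no longer errors out when an XLA TPU/NeuronCore device is present, even if the CPU build lacks bf16 support. A condensed sketch of that decision, reusing the helper names from the diff (illustrative only, not the actual __post_init__ logic):

from transformers.utils import is_torch_bf16_cpu_available, is_torch_tpu_available

def bf16_allowed_without_cuda() -> bool:
    # bf16 with no_cuda is acceptable if the CPU build supports bf16 (torch>=1.10)
    # or an XLA TPU/NeuronCore device is available.
    return is_torch_bf16_cpu_available() or is_torch_tpu_available()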
@@ -1172,12 +1172,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
+                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
             )
         if self.torchdynamo is not None:
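get_xla_device_type(self.device) reports the XLA hardware backing the current device, so the added != "TPU" clause lets the bf16 flags pass this guard on TPU/NeuronCore. A small probe of the underlying torch_xla API for reference (assumes torch_xla is installed; this is a sketch, not the helper's actual implementation):

import torch_xla.core.xla_model as xm

device = xm.xla_device()            # current XLA device
hw_type = xm.xla_device_hw(device)  # hardware string, e.g. "TPU", "GPU", or "CPU"
print(f"XLA device {device} is backed by {hw_type}")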