Restore fp16 support on xla gpu device (#22300)

Author: Yanming W, 2023-03-21 13:32:43 -07:00 (committed by GitHub)
Parent: 67c2dbdb54
Commit: d35f729649

@@ -598,7 +598,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")
 
         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
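
The removed is_torch_tpu_available() check meant that any run where torch_xla was importable, including XLA on a GPU device, skipped the Trainer's own fp16/bf16 setup, so cuda_amp was never enabled. After the change, only DeepSpeed and SageMaker Model Parallel opt out, since they manage half precision themselves. Below is a minimal sketch of the before/after guard; it is illustrative only, and the SimpleNamespace stand-in for TrainingArguments and the boolean flags are assumptions rather than Trainer code.

from types import SimpleNamespace

# Hypothetical fp16 run on an XLA GPU device: torch_xla is importable,
# so is_torch_tpu_available() would report True even though the device is a GPU.
args = SimpleNamespace(fp16=True, bf16=False, deepspeed=None)
sagemaker_mp = False          # stand-in for is_sagemaker_mp_enabled()
torch_tpu_available = True    # stand-in for is_torch_tpu_available()

# Old guard: the torch_xla check blocks the cuda_amp setup entirely.
old_guard = (args.fp16 or args.bf16) and not (args.deepspeed or sagemaker_mp or torch_tpu_available)

# New guard: only DeepSpeed and SageMaker MP are excluded, so fp16 is configured.
new_guard = (args.fp16 or args.bf16) and not (args.deepspeed or sagemaker_mp)

print(old_guard)  # False -> fp16 silently skipped on XLA GPU
print(new_guard)  # True  -> cuda_amp half precision is set up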