Mirror of https://github.com/huggingface/transformers.git
Restore fp16 support on xla gpu device (#22300)
parent 67c2dbdb54
commit d35f729649
src/transformers/trainer.py
@@ -598,7 +598,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")
 
         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
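For context, a minimal sketch (not part of the commit) of how the restored behavior surfaces to a user. Previously, fp16 was skipped whenever is_torch_tpu_available() returned True, which is presumably also the case on an XLA GPU device where torch_xla is installed; with the guard relaxed, fp16=True once again reaches the cuda_amp branch there. The output_dir value and the commented-out model/dataset names are placeholders, and the snippet assumes a transformers version containing this commit, run on a machine with a (XLA) GPU.

# Sketch: requesting fp16 mixed precision via TrainingArguments.
# After this commit, having torch_xla installed no longer disables
# the cuda_amp half precision backend on a GPU device.
from transformers import TrainingArguments, Trainer

args = TrainingArguments(
    output_dir="out",                    # placeholder path
    fp16=True,                           # request float16 mixed precision
    half_precision_backend="cuda_amp",   # the backend guarded by the changed condition
)

# trainer = Trainer(model=model, args=args, train_dataset=train_dataset)
# trainer.train()  # fp16 autocast is now active on an XLA GPU device as well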