Enable bf16 option for XLA devices (#20684)
commit bcc069ddb8
parent 9858ecd706
@@ -565,7 +565,7 @@ class Trainer:
         logger.info(f"Using {args.half_precision_backend} half precision backend")
 
         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
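The gate this hunk widens can be read as a small predicate: the Trainer sets up its own half-precision backend only when the user asked for fp16/bf16 and no other runtime manages mixed precision itself. A minimal sketch, with plain booleans standing in for the real helpers (args.*, is_sagemaker_mp_enabled(), is_torch_tpu_available()):

# Minimal sketch of the widened gate; plain booleans stand in for the
# real helper calls used in the Trainer.
def should_setup_amp(fp16, bf16, deepspeed, sagemaker_mp, torch_tpu):
    wants_half_precision = fp16 or bf16
    # DeepSpeed, SageMaker Model Parallel, and (after this commit) the
    # XLA/TPU runtime manage their own half precision, so the Trainer
    # skips its cuda_amp/cpu_amp setup in those cases.
    managed_elsewhere = deepspeed or sagemaker_mp or torch_tpu
    return wants_half_precision and not managed_elsewhere

# On a TPU the Trainer no longer tries to configure CUDA AMP:
assert should_setup_amp(False, True, False, False, True) is False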
@@ -1122,9 +1122,9 @@ class TrainingArguments:
 
         if self.bf16 or self.bf16_full_eval:
 
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
-                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+                raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                 # gpu
                 raise ValueError(
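The relaxed check can be exercised in isolation. A sketch under the assumption that the conditions reduce to booleans (the real code calls is_torch_bf16_cpu_available() and is_torch_tpu_available()):

# Sketch of the relaxed bf16 validation; booleans stand in for the helpers.
def validate_bf16(no_cuda, bf16_cpu_ok, tpu_available):
    # Before this commit, any non-CUDA setup without CPU bf16 support was
    # rejected; now an available XLA/TPU runtime also counts as supported.
    if no_cuda and not bf16_cpu_ok and not tpu_available:
        raise ValueError(
            "Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10"
        )

validate_bf16(no_cuda=True, bf16_cpu_ok=False, tpu_available=True)  # ok on TPU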
@@ -1172,12 +1172,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
+                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
             )
 
         if self.torchdynamo is not None:
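With the extra condition, the "unsupported device" error now fires only when the device is neither CUDA, CPU, an XLA GPU, nor an XLA TPU. A hedged sketch of that gate, where xla_device_type stands in for get_xla_device_type(self.device) and the example values are assumptions:

from typing import Optional

# Sketch of the device gate after this hunk; `xla_device_type` stands in
# for get_xla_device_type(self.device), assumed to return e.g. "GPU",
# "TPU", or None for non-XLA devices.
def bf16_unsupported(device_type: str, xla_device_type: Optional[str]) -> bool:
    return (
        device_type != "cuda"
        and xla_device_type != "GPU"
        and xla_device_type != "TPU"  # the condition added by this commit
        and device_type != "cpu"
    )

assert not bf16_unsupported("xla", "TPU")  # TPU/NeuronCore now accepted
assert bf16_unsupported("mps", None)       # other accelerators still rejected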