Enable bf16 option for XLA devices (#20684)
commit bcc069ddb8
parent 9858ecd706
@@ -565,7 +565,7 @@ class Trainer:
             logger.info(f"Using {args.half_precision_backend} half precision backend")

         self.do_grad_scaling = False
-        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled()):
+        if (args.fp16 or args.bf16) and not (args.deepspeed or is_sagemaker_mp_enabled() or is_torch_tpu_available()):
             # deepspeed and SageMaker Model Parallel manage their own half precision
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
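The Trainer change above skips the CUDA/CPU AMP and gradient-scaling setup whenever an XLA/TPU device is detected, since bf16 on TPUs is handled by torch_xla rather than by torch.cuda.amp. As a point of reference (not part of this diff), a minimal sketch of the torch_xla-side mechanism, assuming torch_xla is installed and using its documented XLA_USE_BF16 environment variable:

import os

# Sketch only: torch_xla stores float32 tensors as bfloat16 on TPU
# when XLA_USE_BF16=1 is set in the environment.
os.environ["XLA_USE_BF16"] = "1"

import torch
import torch_xla.core.xla_model as xm  # requires torch_xla

device = xm.xla_device()          # an XLA device, e.g. a TPU core
x = torch.randn(2, 2).to(device)  # kept in bf16 on TPU under XLA_USE_BF16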
@@ -1122,9 +1122,9 @@ class TrainingArguments:

         if self.bf16 or self.bf16_full_eval:
-            if self.no_cuda and not is_torch_bf16_cpu_available():
+            if self.no_cuda and not is_torch_bf16_cpu_available() and not is_torch_tpu_available():
                 # cpu
-                raise ValueError("Your setup doesn't support bf16/cpu. You need torch>=1.10")
+                raise ValueError("Your setup doesn't support bf16/(cpu, tpu, neuroncore). You need torch>=1.10")
             elif not self.no_cuda and torch.cuda.is_available() and not is_torch_bf16_gpu_available():
                 # gpu
                 raise ValueError(
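With the relaxed check above, passing bf16 options together with no_cuda no longer raises when a TPU or NeuronCore is visible through torch_xla. A hedged usage sketch, assuming a TPU host with torch_xla installed so that is_torch_tpu_available() returns True (output_dir is a placeholder):

from transformers import TrainingArguments

# Sketch: on a TPU host this no longer raises
# "Your setup doesn't support bf16/(cpu, tpu, neuroncore)".
args = TrainingArguments(
    output_dir="out",     # placeholder
    no_cuda=True,         # no CUDA device; the XLA/TPU device is used instead
    bf16=True,            # bf16 mixed precision, now allowed on XLA devices
    bf16_full_eval=True,  # bf16 evaluation, covered by the same check
)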
@@ -1172,12 +1172,13 @@ class TrainingArguments:
             and is_torch_available()
             and (self.device.type != "cuda")
             and (get_xla_device_type(self.device) != "GPU")
+            and (get_xla_device_type(self.device) != "TPU")
             and (self.device.type != "cpu")
             and (self.bf16 or self.bf16_full_eval)
         ):
             raise ValueError(
                 "BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
-                " (`--bf16_full_eval`) can only be used on CUDA or CPU devices."
+                " (`--bf16_full_eval`) can only be used on CUDA or CPU/TPU/NeuronCore devices."
             )

         if self.torchdynamo is not None:
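The new condition relies on get_xla_device_type() to tell XLA TPUs apart from XLA GPUs; its implementation is not part of this diff. A minimal sketch of such a helper, assuming torch_xla's xm.xla_real_devices(), which resolves an XLA device to a string like "TPU:0" or "GPU:0" (the function name below is made up and is not the transformers implementation):

from typing import Optional

import torch

def get_xla_device_type_sketch(device: "torch.device") -> Optional[str]:
    # Sketch only: return the hardware behind an XLA device
    # ("TPU", "GPU", "CPU") or None for non-XLA devices.
    if device.type != "xla":
        return None
    import torch_xla.core.xla_model as xm  # requires torch_xla

    # xm.xla_real_devices maps e.g. xla:0 -> "TPU:0"; keep the hardware part.
    return xm.xla_real_devices([device])[0].split(":")[0]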