Mirror of https://github.com/huggingface/transformers.git, synced 2025-07-31 02:02:21 +06:00
do not scale gradient in bf16 mode (#21428)
* do not scale gradient in bf16 mode
* fix since args.fp16 might be None
* fixed typo
* typo
* only do if grad scaling is true
* self.amp_dtype == torch.float16 is true
* put back prop when fsdp is not None
This commit is contained in:
parent 197e7ce911
commit fb13a7df95
@@ -595,27 +595,26 @@ class Trainer:
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
                 self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
-                self.do_grad_scaling = True
-                if self.sharded_ddp is not None:
-                    self.scaler = ShardedGradScaler()
-                elif self.fsdp is not None:
-                    if self.amp_dtype == torch.float16:
-                        from torch.distributed.fsdp.sharded_grad_scaler import (
-                            ShardedGradScaler as FSDPShardedGradScaler,
-                        )
-
-                        self.scaler = FSDPShardedGradScaler()
-                    else:
-                        self.do_grad_scaling = False
-                        self.use_cuda_amp = False
-                        self.amp_dtype = None
-
-                elif is_torch_tpu_available():
-                    from torch_xla.amp import GradScaler
-
-                    self.scaler = GradScaler()
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler()
+                # bf16 does not need grad scaling
+                self.do_grad_scaling = self.amp_dtype == torch.float16
+                if self.do_grad_scaling:
+                    if self.sharded_ddp is not None:
+                        self.scaler = ShardedGradScaler()
+                    elif self.fsdp is not None:
+                        from torch.distributed.fsdp.sharded_grad_scaler import (
+                            ShardedGradScaler as FSDPShardedGradScaler,
+                        )
+
+                        self.scaler = FSDPShardedGradScaler()
+                    elif is_torch_tpu_available():
+                        from torch_xla.amp import GradScaler
+
+                        self.scaler = GradScaler()
+                    else:
+                        self.scaler = torch.cuda.amp.GradScaler()
+                elif self.fsdp is not None:
+                    self.use_cuda_amp = False
+                    self.amp_dtype = None
             elif args.half_precision_backend == "cpu_amp":
                 self.use_cpu_amp = True
                 self.amp_dtype = torch.bfloat16
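Why the change: torch.cuda.amp scales the loss under float16 because fp16's narrow exponent range can make small gradients underflow to zero, while bfloat16 keeps float32's exponent range, so gradient scaling buys nothing and can be skipped. Below is a minimal sketch of the same pattern outside the Trainer, assuming a CUDA device and a toy model; the names are illustrative, not part of transformers.

import torch

amp_dtype = torch.bfloat16  # or torch.float16
model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Mirror of `do_grad_scaling = self.amp_dtype == torch.float16` above:
# a disabled GradScaler turns scale()/step()/update() into pass-throughs.
scaler = torch.cuda.amp.GradScaler(enabled=amp_dtype == torch.float16)

for _ in range(2):
    inputs = torch.randn(4, 8, device="cuda")
    with torch.autocast(device_type="cuda", dtype=amp_dtype):
        loss = model(inputs).mean()
    scaler.scale(loss).backward()  # returns the loss unchanged when disabled
    scaler.step(optimizer)         # plain optimizer.step() when disabled
    scaler.update()
    optimizer.zero_grad()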
@@ -669,7 +668,7 @@ class Trainer:
 
         # torch.compile
         if args.torch_compile and not is_torch_compile_available():
-            raise RuntimeError("Using torch.compile requires a nighly install of PyTorch.")
+            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
 
     def add_callback(self, callback):
         """
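On the second hunk: besides the typo fix, the guard exists because torch.compile only shipped in PyTorch 2.0 nightlies at the time of this commit. A minimal sketch of such a guard, assuming only that torch.compile is exposed as an attribute on the torch module when available; maybe_compile is a made-up name, not the library's is_torch_compile_available() helper.

import torch

def maybe_compile(model):
    # torch.compile first appeared in PyTorch 2.0 (nightly-only when this commit
    # landed), so probe for the attribute before calling it.
    if not hasattr(torch, "compile"):
        raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
    return torch.compile(model)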