Mirror of https://github.com/huggingface/transformers.git (synced 2025-08-02 03:01:07 +06:00)
do not scale gradient in bf16 mode (#21428)
* do not scale gradient in bf16 mode
* fix since args.fp16 might be None
* fixed typo
* typo
* only do if grad scaling is true
* self.amp_dtype == torch.float16 is true
* put back prop when fsdp is not None
This commit is contained in:
parent 197e7ce911
commit fb13a7df95
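Background for the first hunk below: gradient (loss) scaling exists to keep fp16 gradients from underflowing, because fp16 has a very narrow exponent range; bf16 keeps the full float32 exponent range, so a scaler buys nothing there. A minimal sketch (not part of this commit) that prints the two dynamic ranges:

import torch

# fp16: smallest normal ≈ 6.1e-05, max = 65504 -> small gradients underflow
# unless the loss is scaled up before backward().
print(torch.finfo(torch.float16))
# bf16: same exponent range as float32 (max ≈ 3.4e38), so gradients rarely
# underflow and a GradScaler is unnecessary.
print(torch.finfo(torch.bfloat16))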
@@ -595,27 +595,26 @@ class Trainer:
             if args.half_precision_backend == "cuda_amp":
                 self.use_cuda_amp = True
                 self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
-                self.do_grad_scaling = True
-                if self.sharded_ddp is not None:
-                    self.scaler = ShardedGradScaler()
-                elif self.fsdp is not None:
-                    if self.amp_dtype == torch.float16:
+                # bf16 does not need grad scaling
+                self.do_grad_scaling = self.amp_dtype == torch.float16
+                if self.do_grad_scaling:
+                    if self.sharded_ddp is not None:
+                        self.scaler = ShardedGradScaler()
+                    elif self.fsdp is not None:
                         from torch.distributed.fsdp.sharded_grad_scaler import (
                             ShardedGradScaler as FSDPShardedGradScaler,
                         )
 
                         self.scaler = FSDPShardedGradScaler()
-                    else:
-                        self.do_grad_scaling = False
-                        self.use_cuda_amp = False
-                        self.amp_dtype = None
-                elif is_torch_tpu_available():
-                    from torch_xla.amp import GradScaler
-
-                    self.scaler = GradScaler()
-                else:
-                    self.scaler = torch.cuda.amp.GradScaler()
+                    elif is_torch_tpu_available():
+                        from torch_xla.amp import GradScaler
+
+                        self.scaler = GradScaler()
+                    else:
+                        self.scaler = torch.cuda.amp.GradScaler()
+                elif self.fsdp is not None:
+                    self.use_cuda_amp = False
+                    self.amp_dtype = None
             elif args.half_precision_backend == "cpu_amp":
                 self.use_cpu_amp = True
                 self.amp_dtype = torch.bfloat16
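Taken out of the Trainer, the pattern the hunk above introduces looks roughly like the sketch below (hypothetical names, not code from this commit): build a torch.cuda.amp.GradScaler only when the autocast dtype is float16, and use a plain backward/step for bf16.

import torch

amp_dtype = torch.bfloat16  # torch.float16 when training with --fp16
do_grad_scaling = amp_dtype == torch.float16
# bf16 does not need grad scaling, so only the fp16 path gets a scaler.
scaler = torch.cuda.amp.GradScaler() if do_grad_scaling else None

def backward_and_step(loss, optimizer):
    if scaler is not None:
        # fp16 path: scale the loss so tiny gradients survive fp16's range,
        # let scaler.step() unscale before stepping, then update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
    else:
        # bf16 path: no scaling needed.
        loss.backward()
        optimizer.step()

This dtype gate is what `self.do_grad_scaling = self.amp_dtype == torch.float16` encodes inside `Trainer.__init__`.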
@@ -669,7 +668,7 @@ class Trainer:
 
         # torch.compile
         if args.torch_compile and not is_torch_compile_available():
-            raise RuntimeError("Using torch.compile requires a nighly install of PyTorch.")
+            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
 
     def add_callback(self, callback):
         """
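The second hunk only corrects a typo ("nighly" -> "nightly") in the error message; the guard it sits in exists because torch.compile was only available in PyTorch 2.0 nightly builds when this commit landed. A hypothetical stand-in for that availability check (the real is_torch_compile_available in transformers inspects the installed torch version; this sketch is only an illustration):

import torch

def torch_compile_is_available() -> bool:
    # Hypothetical check: torch.compile first appeared in PyTorch 2.0,
    # which shipped only as a nightly build at the time of this commit.
    return hasattr(torch, "compile")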