do not scale gradient in bf16 mode (#21428)

* do not scale gradients in bf16 mode

* fix since args.fp16 might be None

* fixed typo

* typo

* only create the scaler if grad scaling is enabled

* enable grad scaling only when self.amp_dtype == torch.float16

* put back the property resets when fsdp is not None
Kashif Rasul 2023-02-03 17:57:33 +01:00 committed by GitHub
parent 197e7ce911
commit fb13a7df95


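Why bf16 skips the scaler: fp16 has only a 5-bit exponent, so small gradients can underflow to zero unless the loss is scaled up before backward(); bf16 keeps fp32's 8-bit exponent range, so gradients survive without scaling and torch.cuda.amp.GradScaler is unnecessary. Below is a minimal sketch of a generic AMP training step that mirrors the new do_grad_scaling check; model, optimizer, inputs and labels are placeholders, and this is not the Trainer implementation itself.

import torch

amp_dtype = torch.bfloat16                      # torch.float16 when fp16 training is requested
do_grad_scaling = amp_dtype == torch.float16    # same check the diff below introduces
scaler = torch.cuda.amp.GradScaler() if do_grad_scaling else None

def training_step(model, optimizer, inputs, labels):
    with torch.autocast(device_type="cuda", dtype=amp_dtype):
        loss = torch.nn.functional.cross_entropy(model(inputs), labels)
    if do_grad_scaling:
        scaler.scale(loss).backward()           # scale up so fp16 gradients do not underflow
        scaler.step(optimizer)                  # unscales, skips the step on inf/nan
        scaler.update()
    else:
        loss.backward()                         # bf16: plain backward/step, no scaler
        optimizer.step()
    optimizer.zero_grad()
    return loss.detach()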
@@ -595,27 +595,26 @@ class Trainer:
         if args.half_precision_backend == "cuda_amp":
             self.use_cuda_amp = True
             self.amp_dtype = torch.float16 if args.fp16 else torch.bfloat16
-            self.do_grad_scaling = True
-            if self.sharded_ddp is not None:
-                self.scaler = ShardedGradScaler()
-            elif self.fsdp is not None:
-                if self.amp_dtype == torch.float16:
+            # bf16 does not need grad scaling
+            self.do_grad_scaling = self.amp_dtype == torch.float16
+            if self.do_grad_scaling:
+                if self.sharded_ddp is not None:
+                    self.scaler = ShardedGradScaler()
+                elif self.fsdp is not None:
                     from torch.distributed.fsdp.sharded_grad_scaler import (
                         ShardedGradScaler as FSDPShardedGradScaler,
                     )
 
                     self.scaler = FSDPShardedGradScaler()
+                elif is_torch_tpu_available():
+                    from torch_xla.amp import GradScaler
+
+                    self.scaler = GradScaler()
                 else:
-                    self.do_grad_scaling = False
-                    self.use_cuda_amp = False
-                    self.amp_dtype = None
-            elif is_torch_tpu_available():
-                from torch_xla.amp import GradScaler
-
-                self.scaler = GradScaler()
-            else:
-                self.scaler = torch.cuda.amp.GradScaler()
+                    self.scaler = torch.cuda.amp.GradScaler()
+            elif self.fsdp is not None:
+                self.use_cuda_amp = False
+                self.amp_dtype = None
         elif args.half_precision_backend == "cpu_amp":
             self.use_cpu_amp = True
             self.amp_dtype = torch.bfloat16
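Read as straight-line logic, the hunk above boils down to the dtype-driven selection sketched here. This is an illustrative distillation, not a drop-in for the Trainer; the fairscale import in the sharded_ddp branch is an assumption, since that import is not visible in this hunk.

import torch

def pick_scaler(amp_dtype, sharded_ddp=None, fsdp=None, is_tpu=False):
    # bf16 path: do_grad_scaling is False and no scaler is ever constructed
    if amp_dtype != torch.float16:
        return None
    if sharded_ddp is not None:
        # assumption: the ShardedGradScaler in the sharded_ddp branch comes from fairscale
        from fairscale.optim.grad_scaler import ShardedGradScaler
        return ShardedGradScaler()
    if fsdp is not None:
        from torch.distributed.fsdp.sharded_grad_scaler import (
            ShardedGradScaler as FSDPShardedGradScaler,
        )
        return FSDPShardedGradScaler()
    if is_tpu:
        from torch_xla.amp import GradScaler
        return GradScaler()
    return torch.cuda.amp.GradScaler()

One detail the sketch omits: when grad scaling is off and FSDP is in use, the hunk still resets use_cuda_amp and amp_dtype (the "put back" item in the commit message).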
@@ -669,7 +668,7 @@ class Trainer:
 
         # torch.compile
         if args.torch_compile and not is_torch_compile_available():
-            raise RuntimeError("Using torch.compile requires a nighly install of PyTorch.")
+            raise RuntimeError("Using torch.compile requires a nightly install of PyTorch.")
 
     def add_callback(self, callback):
         """