diff --git a/src/transformers/optimization.py b/src/transformers/optimization.py
index 3727784fba9..1ee2f41d2f9 100644
--- a/src/transformers/optimization.py
+++ b/src/transformers/optimization.py
@@ -444,9 +444,8 @@ def get_scheduler(
 
         def scheduler_hook(param):
             # Since the optimizer hook has been already attached we only need to
-            # attach the scheduler hook
-            if param.grad is not None:
-                scheduler_dict[param].step()
+            # attach the scheduler hook, the gradients have been zeroed here
+            scheduler_dict[param].step()
 
         for param in optimizer_dict.keys():
             if param.requires_grad:
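
For context, here is a minimal sketch of the per-parameter hook pattern this hunk touches, and of why the `param.grad is not None` guard is dropped: the optimizer hook runs first and zeroes the gradient, so by the time the scheduler hook fires the guard would always be false and the scheduler would never step. This is not the transformers source; names such as `model`, and the way `optimizer_dict` / `scheduler_dict` are built here, are illustrative assumptions that mirror the hunk. Only PyTorch's `register_post_accumulate_grad_hook` (available in PyTorch 2.1+) is relied on.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import LambdaLR

model = torch.nn.Linear(4, 2)  # illustrative stand-in model

# One optimizer and one scheduler per trainable parameter (layer-wise setup).
optimizer_dict = {p: SGD([p], lr=0.1) for p in model.parameters() if p.requires_grad}
scheduler_dict = {p: LambdaLR(optimizer_dict[p], lambda step: 0.95**step) for p in optimizer_dict}


def optimizer_hook(param):
    # Steps this parameter's optimizer and zeroes its gradient (sets it to None)
    # as soon as the gradient has been accumulated.
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad()


def scheduler_hook(param):
    # Runs after optimizer_hook for the same parameter, so param.grad has
    # already been zeroed; guarding on `param.grad is not None` here would
    # always skip the step, which is what the diff above removes.
    scheduler_dict[param].step()


# Hooks fire in registration order: optimizer first, then scheduler.
for param in optimizer_dict:
    param.register_post_accumulate_grad_hook(optimizer_hook)
    param.register_post_accumulate_grad_hook(scheduler_hook)

# Usage: a single backward pass triggers both hooks for every parameter.
loss = model(torch.randn(3, 4)).sum()
loss.backward()
```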